Compare commits

...

86 Commits

Author SHA1 Message Date
Ariel Rin
297f98f046 Version Bump 2.15.1 2022-07-19 19:07:47 +10:00
Ariel Rin
27dad05927 Merge branch 'fix-discord-update-username' into 'master'
Fix discord update username

See merge request allianceauth/allianceauth!1442
2022-07-19 08:48:26 +00:00
Erik Kalkoken
697e9dd772 Fix discord update username 2022-07-19 08:48:25 +00:00
Ariel Rin
65f2efc890 Version Bump 2.15.0 2022-07-18 19:22:50 +10:00
Ariel Rin
def30900b4 Merge branch 'discord_bugfixes_and_refactor' into 'master'
Fix managed roles and reserved groups bugs in Discord Service and more

Closes #1345 and #1334

See merge request allianceauth/allianceauth!1429
2022-07-18 09:12:32 +00:00
Erik Kalkoken
d7fabccddd Fix managed roles and reserved groups bugs in Discord Service and more 2022-07-18 09:12:32 +00:00
Ariel Rin
45289e1d17 Merge branch 'fix-filterdropdown-bug' into 'master'
Fix filterdropdown bug

See merge request allianceauth/allianceauth!1439
2022-07-18 09:04:35 +00:00
ErikKalkoken
7b9bf08aa3 Fix bug in filterDropDown bundle 2022-07-15 13:39:48 +02:00
Ariel Rin
def6431052 Version Bump 2.14.0 2022-07-11 14:27:49 +10:00
Ariel Rin
22a270aedb Merge branch 'filterdropdown-backwards-compatibility' into 'master'
Add filterdropdown bundle to AA2 to ensure backwards compatibility

See merge request allianceauth/allianceauth!1437
2022-07-11 04:15:25 +00:00
Peter Pfeufer
c930f7bbeb Also adds timers.js, eve-time.js and refresh_notifications.js
As these seem to be used in some apps as well
2022-07-09 15:57:43 +02:00
Peter Pfeufer
64ee273953 Add filterdropdown bundle to AA2 to ensure backwards compatibility 2022-07-09 13:43:05 +02:00
Ariel Rin
3706a1aedf Merge branch 'improve-autodocs-for-models' into 'master'
Improve autodocs for models & more

See merge request allianceauth/allianceauth!1435
2022-07-07 07:38:58 +00:00
Ariel Rin
47f1b77320 Merge branch 'consolidate-redis-client-access' into 'master'
Ensure backwards compatibility when fetching a redis client

See merge request allianceauth/allianceauth!1428
2022-07-07 07:37:21 +00:00
Erik Kalkoken
8dec242a93 Ensure backwards compatibility when fetching a redis client 2022-07-07 07:37:21 +00:00
ErikKalkoken
2ff200c566 Refer to django-esi docs 2022-06-27 13:43:45 +02:00
ErikKalkoken
091a2637ea Add extension to improve autodocs for Django models & enable source links 2022-06-27 13:41:15 +02:00
Ariel Rin
113977b19f Version Bump 2.13.0 2022-06-18 13:07:36 +10:00
Ariel Rin
8f39b50b6d Merge branch 'Maestro-Zacht-fix-fat-attributeerror' into 'master'
fixed attribute error

See merge request allianceauth/allianceauth!1421
2022-06-18 02:53:11 +00:00
Maestro-Zacht
95b309c358 fixed attribute error 2022-06-18 02:53:11 +00:00
Ariel Rin
cf3df3b715 Merge branch 'fix_issue_1328' into 'master'
Fix: Changing group's state setting does not kick existing non-conforming group members

Closes #1328

See merge request allianceauth/allianceauth!1400
2022-06-18 02:47:14 +00:00
Erik Kalkoken
d815028c4d Fix: Changing group's state setting does not kick existing non-conforming group members 2022-06-18 02:47:14 +00:00
Ariel Rin
ac5570abe2 Merge branch 'fix_issue_1268' into 'master'
Fix: Service group updates broken when adding users to groups

Closes #1268

See merge request allianceauth/allianceauth!1403
2022-06-18 02:41:23 +00:00
Erik Kalkoken
84ad571aa4 Fix: Service group updates broken when adding users to groups 2022-06-18 02:41:23 +00:00
Ariel Rin
38e7705ae7 Merge branch 'docs-dark-mode' into 'master'
Add automatic dark mode to docs

See merge request allianceauth/allianceauth!1427
2022-06-18 02:39:59 +00:00
ErikKalkoken
0b6af014fa Add automatic dark mode to docs 2022-06-17 21:49:18 +02:00
Ariel Rin
2401f2299d Merge branch 'fix-doc-redis-issue' into 'master'
Fix: Broken docs generation on readthedocs.org (2nd attempt)

See merge request allianceauth/allianceauth!1425
2022-06-17 11:58:45 +00:00
Erik Kalkoken
919768c8bb Fix: Broken docs generation on readthedocs.org (2nd attempt) 2022-06-17 11:58:45 +00:00
Ariel Rin
24db21463b Merge branch 'docs-template-tags-example' into 'master'
Add example for template tags to docs

See merge request allianceauth/allianceauth!1426
2022-06-17 11:58:05 +00:00
Erik Kalkoken
1e029af83a Add example for template tags to docs 2022-06-17 11:58:05 +00:00
Ariel Rin
2b31be789d Merge branch 'fix-issue-1336' into 'master'
Fix: Broken docs generation on readthedocs.org

Closes #1336

See merge request allianceauth/allianceauth!1423
2022-06-06 10:48:16 +00:00
Erik Kalkoken
bf1b4bb549 Fix: Broken docs generation on readthedocs.org 2022-06-06 10:48:16 +00:00
Ariel Rin
dd42b807f0 Version Bump 2.12.1 2022-05-13 00:19:45 +10:00
Ariel Rin
542fbafd98 Merge branch 'cherry-pick-4836559a' into 'v2.12.x'
Merge branch 'fix-decimal_widthratio-template-tag' into 'v2.12.x'

See merge request allianceauth/allianceauth!1420
2022-05-12 14:14:01 +00:00
Ariel Rin
37b9f5c882 Merge branch 'fix-decimal_widthratio-template-tag' into 'v3.x'
[FIX] Division by zero in decimal_widthratio template tag

See merge request allianceauth/allianceauth!1419

(cherry picked from commit 4836559abe)

8dd07b97 [FIX] Devision by zero in decimal_widthratio template tag
17b06c88 Make it a string in accordance to the return value type
2022-05-12 13:33:45 +00:00
Ariel Rin
5bde9a6952 Version Bump 2.12.0 2022-05-12 18:54:22 +10:00
Ariel Rin
23ad9d02d3 Merge branch 'cherry-pick-7fa76d6d' into 'v2.11.x'
Update GitLab CI to conform with the changes to artifacts collection, 2.11.x backport

See merge request allianceauth/allianceauth!1418
2022-05-12 04:30:07 +00:00
Ariel Rin
f99878cc29 Update .gitlab-ci.yml 2022-05-12 04:07:43 +00:00
Ariel Rin
e64431b06c Merge branch 'update-gitlab-ci' into 'v3.x'
Update GitLab CI to conform with the changes to artifacts collection

See merge request allianceauth/allianceauth!1417

(cherry picked from commit 7fa76d6d37)

a3cce358 Update GitLab CI to conform with the changes to artifacts collection
2022-05-12 04:06:04 +00:00
Ariel Rin
0b2993c1c3 Merge branch 'improve_notifications_2' into 'v2.11.x'
Improve notifications

See merge request allianceauth/allianceauth!1411
2022-05-12 04:02:17 +00:00
Erik Kalkoken
75bccf1b0f Improve notifications 2022-05-12 04:02:17 +00:00
Ariel Rin
945bc92898 Merge branch 'admin-dash-improvement' into 'v2.11.x'
Improve Admin Celery Bar

See merge request allianceauth/allianceauth!1414
2022-05-12 03:57:02 +00:00
Ariel Rin
ec7d14a839 Merge branch 'fix_issue_1222' into 'v2.11.x'
Close security loopholes to make non-superuser admins usable

See merge request allianceauth/allianceauth!1413
2022-05-12 03:56:22 +00:00
Erik Kalkoken
dd1a368ff6 Close security loopholes to make non-superuser admins usable 2022-05-12 03:56:22 +00:00
colcrunch
54085617dc Add a few pixels of margin-top to bar labels to better center them. 2022-04-16 15:46:01 -04:00
colcrunch
8cdc5af453 Improve celery bar by using decimalized width values (2 decimal places) to reduce likelyhood of an empty portion of the bar. 2022-04-16 15:44:53 -04:00
Ariel Rin
da93940e13 Just an empty Tag Commit, because 2.11.2 bump went wonky 2022-03-29 14:48:39 +10:00
Ariel Rin
f53b43d9dc Merge branch 'master' of https://gitlab.com/allianceauth/allianceauth into v2.11.x 2022-03-29 14:47:40 +10:00
Ariel Rin
497a167ca7 Version Bump v2.11.2 2022-03-29 14:46:59 +10:00
Ariel Rin
852c5a3037 Bump Django-ESI to 4.x, inc breaking CCP change in 4.0.1 2022-03-29 14:40:30 +10:00
Ariel Rin
90f6777a7a Version Bump 2.11.1 2022-03-20 14:42:39 +10:00
Ariel Rin
a8d890abaf Merge branch 'improve_task_statistics' into 'master'
Improve task statistics

See merge request allianceauth/allianceauth!1409
2022-03-09 10:04:14 +00:00
Erik Kalkoken
79379b444c Improve task statistics 2022-03-09 10:04:13 +00:00
Ariel Rin
ace1de5c68 Merge branch 'fix-docker-new-redis' into 'master'
Fix docker for new redis

See merge request allianceauth/allianceauth!1406
2022-03-09 10:02:01 +00:00
Kevin McKernan
5d6128e9ea remove collectstatic command from dockerfile 2022-03-01 13:23:49 -07:00
Ariel Rin
131cc5ed0a Version Bump 2.11.0 2022-02-26 17:26:55 +10:00
Ariel Rin
9297bed43f Version Bump 2.10.2 2022-02-26 16:37:20 +10:00
Ariel Rin
b2fddc683a Merge branch 'master' of https://gitlab.com/allianceauth/allianceauth into v2.10.x 2022-02-26 16:32:45 +10:00
Ariel Rin
9af634d16a Merge branch 'fix_show_available_groups_for_user_only' into 'master'
Fix: Users can be assigned to groups depite not matching state restrictions

See merge request allianceauth/allianceauth!1402
2022-02-26 05:19:45 +00:00
Erik Kalkoken
a68163caa3 Fix: Users can be assigned to groups depite not matching state restrictions 2022-02-26 05:19:45 +00:00
Ariel Rin
00770fd034 Merge branch 'improve_celery_info_on_dashboard' into 'master'
Improve celery infos on Dashboard

See merge request allianceauth/allianceauth!1384
2022-02-26 05:15:30 +00:00
Erik Kalkoken
01164777ed Improve celery infos on Dashboard 2022-02-26 05:15:30 +00:00
Ariel Rin
00f5e3e1e0 Version Bump 2.10.1 2022-02-21 00:02:12 +10:00
Ariel Rin
8b2527f408 Merge branch 'capsleekxmpp' into 'master'
Cap sleekxmpp to 1.3.2

See merge request allianceauth/allianceauth!1401
2022-02-20 13:44:27 +00:00
Ariel Rin
b7500e4e4e Cap sleekxmpp to 1.3.2 2022-02-20 13:44:27 +00:00
Kevin McKernan
4f4bd0c419 add note to docker README about Apple M1 support 2022-02-20 23:41:12 +10:00
Ariel Rin
8ae4e02012 Merge branch 'docker-bump-version' into 'v2.10.x'
Bump version for Docker deployment to v2.10.x.

See merge request allianceauth/allianceauth!1396
2022-02-02 13:26:33 +00:00
Weyland
cc9a07197d Bump version for Docker deployment to v2.10.x. 2022-02-02 13:30:05 +01:00
Ariel Rin
f18dd1029b Version Bump v2.10.0 2022-01-31 20:58:09 +10:00
Ariel Rin
fd8d43571a Merge branch 'analytics' into 'master'
Analytics - Extra Ignore Path

See merge request allianceauth/allianceauth!1347
2022-01-31 09:23:43 +00:00
Ariel Rin
13e88492f1 Analytics - Extra Ignore Path 2022-01-31 09:23:43 +00:00
Ariel Rin
38df580a56 Merge branch 'analytics_update' into 'master'
Add setting to disable analytics

See merge request allianceauth/allianceauth!1373
2022-01-27 05:14:12 +00:00
Erik Kalkoken
ba39318313 Add setting to disable analytics 2022-01-27 05:14:11 +00:00
Ariel Rin
d8c6035405 Merge branch 'ts3_reserved_groups' into 'master'
Implement reserved group names in Teamspeak3 service module.

See merge request allianceauth/allianceauth!1380
2022-01-27 05:10:22 +00:00
Ariel Rin
2ef3da916b Merge branch 'datatablessavestate' into 'master'
Add DataTables stateSave feature

See merge request allianceauth/allianceauth!1374
2022-01-27 05:05:37 +00:00
Ariel Rin
d32d8b26ce Merge branch 'delete_characters' into 'master'
Fix: Can not update biomassed characters

See merge request allianceauth/allianceauth!1381
2022-01-27 05:02:57 +00:00
Erik Kalkoken
f348b1a34c Fix: Can not update biomassed characters 2022-01-27 05:02:57 +00:00
Ariel Rin
86aaa3edda Merge branch 'fix-grafana-image-2' into 'master'
fix grafana image again, thanks grafana for not tagging your new images properly

See merge request allianceauth/allianceauth!1393
2022-01-27 04:57:40 +00:00
Ariel Rin
26017056c7 Merge branch 'evetime-js-update' into 'master'
Evetime js update

See merge request allianceauth/allianceauth!1395
2022-01-27 04:35:15 +00:00
Peter Pfeufer
e39a3c072b Evetime js update 2022-01-27 04:35:15 +00:00
Kevin McKernan
827291dda4 fix grafana image again, thanks grafana for not tagging your new images properly 2022-01-07 10:48:50 -07:00
Adarnof
8de2c3bfcb Update name of serverquery IP file changed in TS3 v3.13.0
Changelog indicates old filenames are still accepted, but newly installed servers come with the new file names.
Closes #1298
2021-12-16 22:23:15 -05:00
Adarnof
6688f73565 Use integer teamspeak group IDs when filtering. 2021-12-15 23:54:53 -05:00
Adarnof
72740b9e4d Prevent assignment of reserved groups to AuthTSgroup mappings.
Implemented in TS group updates to prevent their creation / delete once
reserved, and the admin site for when a reserved group name is created
but before the TS group sync occurs.
2021-12-08 23:41:10 -05:00
Adarnof
d11832913d Implement reserved group names in Teamspeak3 service module.
Closes #1302
2021-12-01 00:50:29 -05:00
Ariel Rin
dfe62db8ee add datatables savestate feature 2021-11-27 23:02:33 +10:00
135 changed files with 6346 additions and 2649 deletions

1
.gitignore vendored
View File

@@ -76,3 +76,4 @@ celerybeat-schedule
.flake8
.pylintrc
Makefile
.isort.cfg

View File

@@ -54,7 +54,9 @@ test-3.7-core:
artifacts:
when: always
reports:
cobertura: coverage.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml
test-3.8-core:
<<: *only-default
@@ -64,7 +66,9 @@ test-3.8-core:
artifacts:
when: always
reports:
cobertura: coverage.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml
test-3.9-core:
<<: *only-default
@@ -74,7 +78,9 @@ test-3.9-core:
artifacts:
when: always
reports:
cobertura: coverage.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml
test-3.10-core:
<<: *only-default
@@ -84,7 +90,9 @@ test-3.10-core:
artifacts:
when: always
reports:
cobertura: coverage.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml
test-3.11-core:
<<: *only-default
@@ -94,7 +102,9 @@ test-3.11-core:
artifacts:
when: always
reports:
cobertura: coverage.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml
allow_failure: true
test-3.7-all:
@@ -105,7 +115,9 @@ test-3.7-all:
artifacts:
when: always
reports:
cobertura: coverage.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml
test-3.8-all:
<<: *only-default
@@ -115,7 +127,9 @@ test-3.8-all:
artifacts:
when: always
reports:
cobertura: coverage.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml
test-3.9-all:
<<: *only-default
@@ -125,7 +139,9 @@ test-3.9-all:
artifacts:
when: always
reports:
cobertura: coverage.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml
test-3.10-all:
<<: *only-default
@@ -135,7 +151,9 @@ test-3.10-all:
artifacts:
when: always
reports:
cobertura: coverage.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml
test-3.11-all:
<<: *only-default
@@ -145,9 +163,17 @@ test-3.11-all:
artifacts:
when: always
reports:
cobertura: coverage.xml
coverage_report:
coverage_format: cobertura
path: coverage.xml
allow_failure: true
test-docs:
<<: *only-default
image: python:3.9-bullseye
script:
- tox -e docs
deploy_production:
stage: deploy
image: python:3.10-bullseye

View File

@@ -1,7 +1,7 @@
# This will make sure the app is always imported when
# Django starts so that shared_task will use this app.
__version__ = '2.9.4'
__version__ = '2.15.1'
__title__ = 'Alliance Auth'
__url__ = 'https://gitlab.com/allianceauth/allianceauth'
NAME = f'{__title__} v{__version__}'

View File

@@ -1,5 +1,6 @@
from bs4 import BeautifulSoup
from django.conf import settings
from django.utils.deprecation import MiddlewareMixin
from .models import AnalyticsTokens, AnalyticsIdentifier
from .tasks import send_ga_tracking_web_view
@@ -10,6 +11,8 @@ import re
class AnalyticsMiddleware(MiddlewareMixin):
def process_response(self, request, response):
"""Django Middleware: Process Page Views and creates Analytics Celery Tasks"""
if getattr(settings, "ANALYTICS_DISABLED", False):
return response
analyticstokens = AnalyticsTokens.objects.all()
client_id = AnalyticsIdentifier.objects.get(id=1).identifier.hex
try:

View File

@@ -1,11 +1,11 @@
# Generated by Django 3.1.13 on 2021-10-15 05:02
from django.core.exceptions import ObjectDoesNotExist
from django.db import migrations
def modify_aa_team_token_add_page_ignore_paths(apps, schema_editor):
# We can't import the Person model directly as it may be a newer
# version than this migration expects. We use the historical version.
# Add /admin/ and /user_notifications_count/ path to ignore
AnalyticsPath = apps.get_model('analytics', 'AnalyticsPath')
admin = AnalyticsPath.objects.create(ignore_path=r"^\/admin\/.*")
@@ -17,8 +17,19 @@ def modify_aa_team_token_add_page_ignore_paths(apps, schema_editor):
def undo_modify_aa_team_token_add_page_ignore_paths(apps, schema_editor):
# nothing should need to migrate away here?
return True
#
AnalyticsPath = apps.get_model('analytics', 'AnalyticsPath')
Tokens = apps.get_model('analytics', 'AnalyticsTokens')
token = Tokens.objects.get(token="UA-186249766-2")
try:
admin = AnalyticsPath.objects.get(ignore_path=r"^\/admin\/.*", analyticstokens=token)
user_notifications_count = AnalyticsPath.objects.get(ignore_path=r"^\/user_notifications_count\/.*", analyticstokens=token)
admin.delete()
user_notifications_count.delete()
except ObjectDoesNotExist:
# Its fine if it doesnt exist, we just dont want them building up when re-migrating
pass
class Migration(migrations.Migration):

View File

@@ -0,0 +1,40 @@
# Generated by Django 3.2.8 on 2021-10-19 01:47
from django.core.exceptions import ObjectDoesNotExist
from django.db import migrations
def modify_aa_team_token_add_page_ignore_paths(apps, schema_editor):
# Add the /account/activate path to ignore
AnalyticsPath = apps.get_model('analytics', 'AnalyticsPath')
account_activate = AnalyticsPath.objects.create(ignore_path=r"^\/account\/activate\/.*")
Tokens = apps.get_model('analytics', 'AnalyticsTokens')
token = Tokens.objects.get(token="UA-186249766-2")
token.ignore_paths.add(account_activate)
def undo_modify_aa_team_token_add_page_ignore_paths(apps, schema_editor):
#
AnalyticsPath = apps.get_model('analytics', 'AnalyticsPath')
Tokens = apps.get_model('analytics', 'AnalyticsTokens')
token = Tokens.objects.get(token="UA-186249766-2")
try:
account_activate = AnalyticsPath.objects.get(ignore_path=r"^\/account\/activate\/.*", analyticstokens=token)
account_activate.delete()
except ObjectDoesNotExist:
# Its fine if it doesnt exist, we just dont want them building up when re-migrating
pass
class Migration(migrations.Migration):
dependencies = [
('analytics', '0005_alter_analyticspath_ignore_path'),
]
operations = [
migrations.RunPython(modify_aa_team_token_add_page_ignore_paths, undo_modify_aa_team_token_add_page_ignore_paths)
]

View File

@@ -1,7 +1,8 @@
from allianceauth.analytics.tasks import analytics_event
from celery.signals import task_failure, task_success
import logging
from celery.signals import task_failure, task_success
from django.conf import settings
from allianceauth.analytics.tasks import analytics_event
logger = logging.getLogger(__name__)
@@ -11,6 +12,8 @@ def process_failure_signal(
sender, task_id, signal,
args, kwargs, einfo, **kw):
logger.debug("Celery task_failure signal %s" % sender.__class__.__name__)
if getattr(settings, "ANALYTICS_DISABLED", False):
return
category = sender.__module__
@@ -30,6 +33,8 @@ def process_failure_signal(
@task_success.connect
def celery_success_signal(sender, result=None, **kw):
logger.debug("Celery task_success signal %s" % sender.__class__.__name__)
if getattr(settings, "ANALYTICS_DISABLED", False):
return
category = sender.__module__

View File

@@ -21,8 +21,8 @@ if getattr(settings, "ANALYTICS_ENABLE_DEBUG", False) and settings.DEBUG:
# Force sending of analytics data during in a debug/test environemt
# Usefull for developers working on this feature.
logger.warning(
"You have 'ANALYTICS_ENABLE_DEBUG' Enabled! "
"This debug instance will send analytics data!")
"You have 'ANALYTICS_ENABLE_DEBUG' Enabled! "
"This debug instance will send analytics data!")
DEBUG_URL = COLLECTION_URL
ANALYTICS_URL = COLLECTION_URL
@@ -40,13 +40,12 @@ def analytics_event(category: str,
Send a Google Analytics Event for each token stored
Includes check for if its enabled/disabled
Parameters
-------
`category` (str): Celery Namespace
`action` (str): Task Name
`label` (str): Optional, Task Success/Exception
`value` (int): Optional, If bulk, Query size, can be a binary True/False
`event_type` (str): Optional, Celery or Stats only, Default to Celery
Args:
`category` (str): Celery Namespace
`action` (str): Task Name
`label` (str): Optional, Task Success/Exception
`value` (int): Optional, If bulk, Query size, can be a binary True/False
`event_type` (str): Optional, Celery or Stats only, Default to Celery
"""
analyticstokens = AnalyticsTokens.objects.all()
client_id = AnalyticsIdentifier.objects.get(id=1).identifier.hex
@@ -60,20 +59,21 @@ def analytics_event(category: str,
if allowed is True:
tracking_id = token.token
send_ga_tracking_celery_event.s(tracking_id=tracking_id,
client_id=client_id,
category=category,
action=action,
label=label,
value=value).\
apply_async(priority=9)
send_ga_tracking_celery_event.s(
tracking_id=tracking_id,
client_id=client_id,
category=category,
action=action,
label=label,
value=value).apply_async(priority=9)
@shared_task()
def analytics_daily_stats():
"""Celery Task: Do not call directly
Gathers a series of daily statistics and sends analytics events containing them"""
Gathers a series of daily statistics and sends analytics events containing them
"""
users = install_stat_users()
tokens = install_stat_tokens()
addons = install_stat_addons()

View File

@@ -0,0 +1,109 @@
from unittest.mock import patch
from urllib.parse import parse_qs
import requests_mock
from django.test import override_settings
from allianceauth.analytics.tasks import ANALYTICS_URL
from allianceauth.eveonline.tasks import update_character
from allianceauth.tests.auth_utils import AuthUtils
from allianceauth.utils.testing import NoSocketsTestCase
@override_settings(CELERY_ALWAYS_EAGER=True)
@requests_mock.mock()
class TestAnalyticsForViews(NoSocketsTestCase):
@override_settings(ANALYTICS_DISABLED=False)
def test_should_run_analytics(self, requests_mocker):
# given
requests_mocker.post(ANALYTICS_URL)
user = AuthUtils.create_user("Bruce Wayne")
self.client.force_login(user)
# when
response = self.client.get("/dashboard/")
# then
self.assertEqual(response.status_code, 200)
self.assertTrue(requests_mocker.called)
@override_settings(ANALYTICS_DISABLED=True)
def test_should_not_run_analytics(self, requests_mocker):
# given
requests_mocker.post(ANALYTICS_URL)
user = AuthUtils.create_user("Bruce Wayne")
self.client.force_login(user)
# when
response = self.client.get("/dashboard/")
# then
self.assertEqual(response.status_code, 200)
self.assertFalse(requests_mocker.called)
@override_settings(CELERY_ALWAYS_EAGER=True)
@requests_mock.mock()
class TestAnalyticsForTasks(NoSocketsTestCase):
@override_settings(ANALYTICS_DISABLED=False)
@patch("allianceauth.eveonline.models.EveCharacter.objects.update_character")
def test_should_run_analytics_for_successful_task(
self, requests_mocker, mock_update_character
):
# given
requests_mocker.post(ANALYTICS_URL)
user = AuthUtils.create_user("Bruce Wayne")
character = AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
# when
update_character.delay(character.character_id)
# then
self.assertTrue(mock_update_character.called)
self.assertTrue(requests_mocker.called)
payload = parse_qs(requests_mocker.last_request.text)
self.assertListEqual(payload["el"], ["Success"])
@override_settings(ANALYTICS_DISABLED=True)
@patch("allianceauth.eveonline.models.EveCharacter.objects.update_character")
def test_should_not_run_analytics_for_successful_task(
self, requests_mocker, mock_update_character
):
# given
requests_mocker.post(ANALYTICS_URL)
user = AuthUtils.create_user("Bruce Wayne")
character = AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
# when
update_character.delay(character.character_id)
# then
self.assertTrue(mock_update_character.called)
self.assertFalse(requests_mocker.called)
@override_settings(ANALYTICS_DISABLED=False)
@patch("allianceauth.eveonline.models.EveCharacter.objects.update_character")
def test_should_run_analytics_for_failed_task(
self, requests_mocker, mock_update_character
):
# given
requests_mocker.post(ANALYTICS_URL)
mock_update_character.side_effect = RuntimeError
user = AuthUtils.create_user("Bruce Wayne")
character = AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
# when
update_character.delay(character.character_id)
# then
self.assertTrue(mock_update_character.called)
self.assertTrue(requests_mocker.called)
payload = parse_qs(requests_mocker.last_request.text)
self.assertNotEqual(payload["el"], ["Success"])
@override_settings(ANALYTICS_DISABLED=True)
@patch("allianceauth.eveonline.models.EveCharacter.objects.update_character")
def test_should_not_run_analytics_for_failed_task(
self, requests_mocker, mock_update_character
):
# given
requests_mocker.post(ANALYTICS_URL)
mock_update_character.side_effect = RuntimeError
user = AuthUtils.create_user("Bruce Wayne")
character = AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
# when
update_character.delay(character.character_id)
# then
self.assertTrue(mock_update_character.called)
self.assertFalse(requests_mocker.called)

View File

@@ -1,12 +1,22 @@
import requests_mock
from django.test.utils import override_settings
from allianceauth.analytics.tasks import (
analytics_event,
send_ga_tracking_celery_event,
send_ga_tracking_web_view)
from django.test.testcases import TestCase
from allianceauth.utils.testing import NoSocketsTestCase
class TestAnalyticsTasks(TestCase):
def test_analytics_event(self):
GOOGLE_ANALYTICS_DEBUG_URL = 'https://www.google-analytics.com/debug/collect'
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
@requests_mock.Mocker()
class TestAnalyticsTasks(NoSocketsTestCase):
def test_analytics_event(self, requests_mocker):
requests_mocker.register_uri('POST', GOOGLE_ANALYTICS_DEBUG_URL)
analytics_event(
category='allianceauth.analytics',
action='send_tests',
@@ -14,15 +24,19 @@ class TestAnalyticsTasks(TestCase):
value=1,
event_type='Stats')
def test_send_ga_tracking_web_view_sent(self):
# This test sends if the event SENDS to google
# Not if it was successful
def test_send_ga_tracking_web_view_sent(self, requests_mocker):
"""This test sends if the event SENDS to google.
Not if it was successful.
"""
# given
requests_mocker.register_uri('POST', GOOGLE_ANALYTICS_DEBUG_URL)
tracking_id = 'UA-186249766-2'
client_id = 'ab33e241fbf042b6aa77c7655a768af7'
page = '/index/'
title = 'Hello World'
locale = 'en'
useragent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36"
# when
response = send_ga_tracking_web_view(
tracking_id,
client_id,
@@ -30,15 +44,23 @@ class TestAnalyticsTasks(TestCase):
title,
locale,
useragent)
# then
self.assertEqual(response.status_code, 200)
def test_send_ga_tracking_web_view_success(self):
def test_send_ga_tracking_web_view_success(self, requests_mocker):
# given
requests_mocker.register_uri(
'POST',
GOOGLE_ANALYTICS_DEBUG_URL,
json={"hitParsingResult":[{'valid': True}]}
)
tracking_id = 'UA-186249766-2'
client_id = 'ab33e241fbf042b6aa77c7655a768af7'
page = '/index/'
title = 'Hello World'
locale = 'en'
useragent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36"
# when
json_response = send_ga_tracking_web_view(
tracking_id,
client_id,
@@ -46,15 +68,42 @@ class TestAnalyticsTasks(TestCase):
title,
locale,
useragent).json()
# then
self.assertTrue(json_response["hitParsingResult"][0]["valid"])
def test_send_ga_tracking_web_view_invalid_token(self):
def test_send_ga_tracking_web_view_invalid_token(self, requests_mocker):
# given
requests_mocker.register_uri(
'POST',
GOOGLE_ANALYTICS_DEBUG_URL,
json={
"hitParsingResult":[
{
'valid': False,
'parserMessage': [
{
'messageType': 'INFO',
'description': 'IP Address from this hit was anonymized to 1.132.110.0.',
'messageCode': 'VALUE_MODIFIED'
},
{
'messageType': 'ERROR',
'description': "The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details.",
'messageCode': 'VALUE_INVALID', 'parameter': 'tid'
}
],
'hit': '/debug/collect?v=1&tid=UA-IntentionallyBadTrackingID-2&cid=ab33e241fbf042b6aa77c7655a768af7&t=pageview&dp=/index/&dt=Hello World&ul=en&ua=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36&aip=1&an=allianceauth&av=2.9.0a2'
}
]
}
)
tracking_id = 'UA-IntentionallyBadTrackingID-2'
client_id = 'ab33e241fbf042b6aa77c7655a768af7'
page = '/index/'
title = 'Hello World'
locale = 'en'
useragent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36"
# when
json_response = send_ga_tracking_web_view(
tracking_id,
client_id,
@@ -62,18 +111,25 @@ class TestAnalyticsTasks(TestCase):
title,
locale,
useragent).json()
# then
self.assertFalse(json_response["hitParsingResult"][0]["valid"])
self.assertEqual(json_response["hitParsingResult"][0]["parserMessage"][1]["description"], "The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details.")
self.assertEqual(
json_response["hitParsingResult"][0]["parserMessage"][1]["description"],
"The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details."
)
# [{'valid': False, 'parserMessage': [{'messageType': 'INFO', 'description': 'IP Address from this hit was anonymized to 1.132.110.0.', 'messageCode': 'VALUE_MODIFIED'}, {'messageType': 'ERROR', 'description': "The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details.", 'messageCode': 'VALUE_INVALID', 'parameter': 'tid'}], 'hit': '/debug/collect?v=1&tid=UA-IntentionallyBadTrackingID-2&cid=ab33e241fbf042b6aa77c7655a768af7&t=pageview&dp=/index/&dt=Hello World&ul=en&ua=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36&aip=1&an=allianceauth&av=2.9.0a2'}]
def test_send_ga_tracking_celery_event_sent(self):
def test_send_ga_tracking_celery_event_sent(self, requests_mocker):
# given
requests_mocker.register_uri('POST', GOOGLE_ANALYTICS_DEBUG_URL)
tracking_id = 'UA-186249766-2'
client_id = 'ab33e241fbf042b6aa77c7655a768af7'
category = 'test'
action = 'test'
label = 'test'
value = '1'
# when
response = send_ga_tracking_celery_event(
tracking_id,
client_id,
@@ -81,15 +137,23 @@ class TestAnalyticsTasks(TestCase):
action,
label,
value)
# then
self.assertEqual(response.status_code, 200)
def test_send_ga_tracking_celery_event_success(self):
def test_send_ga_tracking_celery_event_success(self, requests_mocker):
# given
requests_mocker.register_uri(
'POST',
GOOGLE_ANALYTICS_DEBUG_URL,
json={"hitParsingResult":[{'valid': True}]}
)
tracking_id = 'UA-186249766-2'
client_id = 'ab33e241fbf042b6aa77c7655a768af7'
category = 'test'
action = 'test'
label = 'test'
value = '1'
# when
json_response = send_ga_tracking_celery_event(
tracking_id,
client_id,
@@ -97,15 +161,42 @@ class TestAnalyticsTasks(TestCase):
action,
label,
value).json()
# then
self.assertTrue(json_response["hitParsingResult"][0]["valid"])
def test_send_ga_tracking_celery_event_invalid_token(self):
def test_send_ga_tracking_celery_event_invalid_token(self, requests_mocker):
# given
requests_mocker.register_uri(
'POST',
GOOGLE_ANALYTICS_DEBUG_URL,
json={
"hitParsingResult":[
{
'valid': False,
'parserMessage': [
{
'messageType': 'INFO',
'description': 'IP Address from this hit was anonymized to 1.132.110.0.',
'messageCode': 'VALUE_MODIFIED'
},
{
'messageType': 'ERROR',
'description': "The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details.",
'messageCode': 'VALUE_INVALID', 'parameter': 'tid'
}
],
'hit': '/debug/collect?v=1&tid=UA-IntentionallyBadTrackingID-2&cid=ab33e241fbf042b6aa77c7655a768af7&t=pageview&dp=/index/&dt=Hello World&ul=en&ua=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36&aip=1&an=allianceauth&av=2.9.0a2'
}
]
}
)
tracking_id = 'UA-IntentionallyBadTrackingID-2'
client_id = 'ab33e241fbf042b6aa77c7655a768af7'
category = 'test'
action = 'test'
label = 'test'
value = '1'
# when
json_response = send_ga_tracking_celery_event(
tracking_id,
client_id,
@@ -113,7 +204,9 @@ class TestAnalyticsTasks(TestCase):
action,
label,
value).json()
# then
self.assertFalse(json_response["hitParsingResult"][0]["valid"])
self.assertEqual(json_response["hitParsingResult"][0]["parserMessage"][1]["description"], "The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details.")
# [{'valid': False, 'parserMessage': [{'messageType': 'INFO', 'description': 'IP Address from this hit was anonymized to 1.132.110.0.', 'messageCode': 'VALUE_MODIFIED'}, {'messageType': 'ERROR', 'description': "The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details.", 'messageCode': 'VALUE_INVALID', 'parameter': 'tid'}], 'hit': '/debug/collect?v=1&tid=UA-IntentionallyBadTrackingID-2&cid=ab33e241fbf042b6aa77c7655a768af7&t=pageview&dp=/index/&dt=Hello World&ul=en&ua=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36&aip=1&an=allianceauth&av=2.9.0a2'}]
self.assertEqual(
json_response["hitParsingResult"][0]["parserMessage"][1]["description"],
"The value provided for parameter 'tid' is invalid. Please see http://goo.gl/a8d4RP#tid for details."
)

View File

@@ -1,26 +1,44 @@
from django.contrib import admin
from django.contrib.auth.admin import UserAdmin as BaseUserAdmin
from django.contrib.auth.models import User as BaseUser, \
Permission as BasePermission, Group
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission as BasePermission
from django.contrib.auth.models import User as BaseUser
from django.db.models import Count, Q
from allianceauth.services.hooks import ServicesHook
from django.db.models.signals import pre_save, post_save, pre_delete, \
post_delete, m2m_changed
from django.db.models.functions import Lower
from django.db.models.signals import (
m2m_changed,
post_delete,
post_save,
pre_delete,
pre_save
)
from django.dispatch import receiver
from django.forms import ModelForm
from django.utils.html import format_html
from django.urls import reverse
from django.utils.html import format_html
from django.utils.text import slugify
from allianceauth.authentication.models import State, get_guest_state,\
CharacterOwnership, UserProfile, OwnershipRecord
from allianceauth.hooks import get_hooks
from allianceauth.eveonline.models import EveCharacter, EveCorporationInfo,\
EveAllianceInfo, EveFactionInfo
from allianceauth.authentication.models import (
CharacterOwnership,
OwnershipRecord,
State,
UserProfile,
get_guest_state
)
from allianceauth.eveonline.models import (
EveAllianceInfo,
EveCharacter,
EveCorporationInfo,
EveFactionInfo
)
from allianceauth.eveonline.tasks import update_character
from .app_settings import AUTHENTICATION_ADMIN_USERS_MAX_GROUPS, \
AUTHENTICATION_ADMIN_USERS_MAX_CHARS
from allianceauth.hooks import get_hooks
from allianceauth.services.hooks import ServicesHook
from .app_settings import (
AUTHENTICATION_ADMIN_USERS_MAX_CHARS,
AUTHENTICATION_ADMIN_USERS_MAX_GROUPS
)
from .forms import UserChangeForm, UserProfileForm
def make_service_hooks_update_groups_action(service):
@@ -59,19 +77,10 @@ def make_service_hooks_sync_nickname_action(service):
return sync_nickname
class QuerysetModelForm(ModelForm):
# allows specifying FK querysets through kwarg
def __init__(self, querysets=None, *args, **kwargs):
querysets = querysets or {}
super().__init__(*args, **kwargs)
for field, qs in querysets.items():
self.fields[field].queryset = qs
class UserProfileInline(admin.StackedInline):
model = UserProfile
readonly_fields = ('state',)
form = QuerysetModelForm
form = UserProfileForm
verbose_name = ''
verbose_name_plural = 'Profile'
@@ -99,6 +108,7 @@ class UserProfileInline(admin.StackedInline):
return False
@admin.display(description="")
def user_profile_pic(obj):
"""profile pic column data for user objects
@@ -111,13 +121,10 @@ def user_profile_pic(obj):
'<img src="{}" class="img-circle">',
user_obj.profile.main_character.portrait_url(size=32)
)
else:
return None
user_profile_pic.short_description = ''
return None
@admin.display(description="user / main", ordering="username")
def user_username(obj):
"""user column data for user objects
@@ -139,18 +146,17 @@ def user_username(obj):
user_obj.username,
user_obj.profile.main_character.character_name
)
else:
return format_html(
'<strong><a href="{}">{}</a></strong>',
link,
user_obj.username,
)
user_username.short_description = 'user / main'
user_username.admin_order_field = 'username'
return format_html(
'<strong><a href="{}">{}</a></strong>',
link,
user_obj.username,
)
@admin.display(
description="Corporation / Alliance (Main)",
ordering="profile__main_character__corporation_name"
)
def user_main_organization(obj):
"""main organization column data for user objects
@@ -159,21 +165,15 @@ def user_main_organization(obj):
"""
user_obj = obj.user if hasattr(obj, 'user') else obj
if not user_obj.profile.main_character:
result = ''
else:
result = user_obj.profile.main_character.corporation_name
if user_obj.profile.main_character.alliance_id:
result += f'<br>{user_obj.profile.main_character.alliance_name}'
elif user_obj.profile.main_character.faction_name:
result += f'<br>{user_obj.profile.main_character.faction_name}'
return ''
result = user_obj.profile.main_character.corporation_name
if user_obj.profile.main_character.alliance_id:
result += f'<br>{user_obj.profile.main_character.alliance_name}'
elif user_obj.profile.main_character.faction_name:
result += f'<br>{user_obj.profile.main_character.faction_name}'
return format_html(result)
user_main_organization.short_description = 'Corporation / Alliance (Main)'
user_main_organization.admin_order_field = \
'profile__main_character__corporation_name'
class MainCorporationsFilter(admin.SimpleListFilter):
"""Custom filter to filter on corporations from mains only
@@ -196,15 +196,13 @@ class MainCorporationsFilter(admin.SimpleListFilter):
def queryset(self, request, qs):
if self.value() is None:
return qs.all()
else:
if qs.model == User:
return qs.filter(
profile__main_character__corporation_id=self.value()
)
else:
return qs.filter(
user__profile__main_character__corporation_id=self.value()
)
if qs.model == User:
return qs.filter(
profile__main_character__corporation_id=self.value()
)
return qs.filter(
user__profile__main_character__corporation_id=self.value()
)
class MainAllianceFilter(admin.SimpleListFilter):
@@ -217,12 +215,14 @@ class MainAllianceFilter(admin.SimpleListFilter):
parameter_name = 'main_alliance_id__exact'
def lookups(self, request, model_admin):
qs = EveCharacter.objects\
.exclude(alliance_id=None)\
.exclude(userprofile=None)\
.values('alliance_id', 'alliance_name')\
.distinct()\
qs = (
EveCharacter.objects
.exclude(alliance_id=None)
.exclude(userprofile=None)
.values('alliance_id', 'alliance_name')
.distinct()
.order_by(Lower('alliance_name'))
)
return tuple(
(x['alliance_id'], x['alliance_name']) for x in qs
)
@@ -230,13 +230,11 @@ class MainAllianceFilter(admin.SimpleListFilter):
def queryset(self, request, qs):
if self.value() is None:
return qs.all()
else:
if qs.model == User:
return qs.filter(profile__main_character__alliance_id=self.value())
else:
return qs.filter(
user__profile__main_character__alliance_id=self.value()
)
if qs.model == User:
return qs.filter(profile__main_character__alliance_id=self.value())
return qs.filter(
user__profile__main_character__alliance_id=self.value()
)
class MainFactionFilter(admin.SimpleListFilter):
@@ -249,12 +247,14 @@ class MainFactionFilter(admin.SimpleListFilter):
parameter_name = 'main_faction_id__exact'
def lookups(self, request, model_admin):
qs = EveCharacter.objects\
.exclude(faction_id=None)\
.exclude(userprofile=None)\
.values('faction_id', 'faction_name')\
.distinct()\
qs = (
EveCharacter.objects
.exclude(faction_id=None)
.exclude(userprofile=None)
.values('faction_id', 'faction_name')
.distinct()
.order_by(Lower('faction_name'))
)
return tuple(
(x['faction_id'], x['faction_name']) for x in qs
)
@@ -262,15 +262,14 @@ class MainFactionFilter(admin.SimpleListFilter):
def queryset(self, request, qs):
if self.value() is None:
return qs.all()
else:
if qs.model == User:
return qs.filter(profile__main_character__faction_id=self.value())
else:
return qs.filter(
user__profile__main_character__faction_id=self.value()
)
if qs.model == User:
return qs.filter(profile__main_character__faction_id=self.value())
return qs.filter(
user__profile__main_character__faction_id=self.value()
)
@admin.display(description="Update main character model from ESI")
def update_main_character_model(modeladmin, request, queryset):
tasks_count = 0
for obj in queryset:
@@ -279,21 +278,48 @@ def update_main_character_model(modeladmin, request, queryset):
tasks_count += 1
modeladmin.message_user(
request,
f'Update from ESI started for {tasks_count} characters'
request, f'Update from ESI started for {tasks_count} characters'
)
update_main_character_model.short_description = \
'Update main character model from ESI'
class UserAdmin(BaseUserAdmin):
"""Extending Django's UserAdmin model
Behavior of groups and characters columns can be configured via settings
"""
inlines = BaseUserAdmin.inlines + [UserProfileInline]
ordering = ('username', )
list_select_related = ('profile__state', 'profile__main_character')
show_full_result_count = True
list_display = (
user_profile_pic,
user_username,
'_state',
'_groups',
user_main_organization,
'_characters',
'is_active',
'date_joined',
'_role'
)
list_display_links = None
list_filter = (
'profile__state',
'groups',
MainCorporationsFilter,
MainAllianceFilter,
MainFactionFilter,
'is_active',
'date_joined',
'is_staff',
'is_superuser'
)
search_fields = ('username', 'character_ownerships__character__character_name')
readonly_fields = ('date_joined', 'last_login')
filter_horizontal = ('groups', 'user_permissions',)
form = UserChangeForm
class Media:
css = {
"all": ("authentication/css/admin.css",)
@@ -303,9 +329,21 @@ class UserAdmin(BaseUserAdmin):
qs = super().get_queryset(request)
return qs.prefetch_related("character_ownerships__character", "groups")
def get_actions(self, request):
actions = super(BaseUserAdmin, self).get_actions(request)
def get_form(self, request, obj=None, **kwargs):
"""Inject current request into change form object."""
MyForm = super().get_form(request, obj, **kwargs)
if obj:
class MyFormInjected(MyForm):
def __new__(cls, *args, **kwargs):
kwargs['request'] = request
return MyForm(*args, **kwargs)
return MyFormInjected
return MyForm
def get_actions(self, request):
actions = super().get_actions(request)
actions[update_main_character_model.__name__] = (
update_main_character_model,
update_main_character_model.__name__,
@@ -349,38 +387,6 @@ class UserAdmin(BaseUserAdmin):
)
return result
inlines = BaseUserAdmin.inlines + [UserProfileInline]
ordering = ('username', )
list_select_related = ('profile__state', 'profile__main_character')
show_full_result_count = True
list_display = (
user_profile_pic,
user_username,
'_state',
'_groups',
user_main_organization,
'_characters',
'is_active',
'date_joined',
'_role'
)
list_display_links = None
list_filter = (
'profile__state',
'groups',
MainCorporationsFilter,
MainAllianceFilter,
MainFactionFilter,
'is_active',
'date_joined',
'is_staff',
'is_superuser'
)
search_fields = (
'username',
'character_ownerships__character__character_name'
)
def _characters(self, obj):
character_ownerships = list(obj.character_ownerships.all())
characters = [obj.character.character_name for obj in character_ownerships]
@@ -389,22 +395,16 @@ class UserAdmin(BaseUserAdmin):
AUTHENTICATION_ADMIN_USERS_MAX_CHARS
)
_characters.short_description = 'characters'
@admin.display(ordering="profile__state")
def _state(self, obj):
return obj.profile.state.name
_state.short_description = 'state'
_state.admin_order_field = 'profile__state'
def _groups(self, obj):
my_groups = sorted(group.name for group in list(obj.groups.all()))
return self._list_2_html_w_tooltips(
my_groups, AUTHENTICATION_ADMIN_USERS_MAX_GROUPS
)
_groups.short_description = 'groups'
def _role(self, obj):
if obj.is_superuser:
role = 'Superuser'
@@ -414,8 +414,6 @@ class UserAdmin(BaseUserAdmin):
role = 'User'
return role
_role.short_description = 'role'
def has_change_permission(self, request, obj=None):
return request.user.has_perm('auth.change_user')
@@ -425,12 +423,28 @@ class UserAdmin(BaseUserAdmin):
def has_delete_permission(self, request, obj=None):
return request.user.has_perm('auth.delete_user')
def get_object(self, *args , **kwargs):
obj = super().get_object(*args , **kwargs)
self.obj = obj # storing current object for use in formfield_for_manytomany
return obj
def formfield_for_manytomany(self, db_field, request, **kwargs):
"""overriding this formfield to have sorted lists in the form"""
if db_field.name == "groups":
kwargs["queryset"] = Group.objects.all().order_by(Lower('name'))
groups_qs = Group.objects.filter(authgroup__states__isnull=True)
obj_state = self.obj.profile.state
if obj_state:
matching_groups_qs = Group.objects.filter(authgroup__states=obj_state)
groups_qs = groups_qs | matching_groups_qs
kwargs["queryset"] = groups_qs.order_by(Lower("name"))
return super().formfield_for_manytomany(db_field, request, **kwargs)
def get_readonly_fields(self, request, obj=None):
if obj and not request.user.is_superuser:
return self.readonly_fields + (
"is_staff", "is_superuser", "user_permissions"
)
return self.readonly_fields
@admin.register(State)
class StateAdmin(admin.ModelAdmin):
@@ -441,10 +455,9 @@ class StateAdmin(admin.ModelAdmin):
qs = super().get_queryset(request)
return qs.annotate(user_count=Count("userprofile__id"))
@admin.display(description="Users", ordering="user_count")
def _user_count(self, obj):
return obj.user_count
_user_count.short_description = 'Users'
_user_count.admin_order_field = 'user_count'
fieldsets = (
(None, {
@@ -500,13 +513,13 @@ class StateAdmin(admin.ModelAdmin):
)
return super().get_fieldsets(request, obj=obj)
def get_readonly_fields(self, request, obj=None):
if not request.user.is_superuser:
return self.readonly_fields + ("permissions",)
return self.readonly_fields
class BaseOwnershipAdmin(admin.ModelAdmin):
class Media:
css = {
"all": ("authentication/css/admin.css",)
}
list_select_related = (
'user__profile__state', 'user__profile__main_character', 'character')
list_display = (
@@ -527,6 +540,11 @@ class BaseOwnershipAdmin(admin.ModelAdmin):
MainAllianceFilter,
)
class Media:
css = {
"all": ("authentication/css/admin.css",)
}
def get_readonly_fields(self, request, obj=None):
if obj and obj.pk:
return 'owner_hash', 'character'

View File

@@ -3,10 +3,14 @@ from django.core.checks import register, Tags
class AuthenticationConfig(AppConfig):
name = 'allianceauth.authentication'
label = 'authentication'
name = "allianceauth.authentication"
label = "authentication"
def ready(self):
super().ready()
from allianceauth.authentication import checks, signals
from allianceauth.authentication import checks, signals # noqa: F401
from allianceauth.authentication.task_statistics import (
signals as celery_signals,
)
register(Tags.security)(checks.check_login_scopes_setting)
celery_signals.reset_counters()

View File

@@ -1,8 +1,66 @@
from django import forms
from django.utils.translation import ugettext_lazy as _
from django.contrib.auth.forms import UserChangeForm as BaseUserChangeForm
from django.contrib.auth.models import Group
from django.core.exceptions import ValidationError
from django.forms import ModelForm
from django.utils.translation import gettext_lazy as _
from allianceauth.authentication.models import User
class RegistrationForm(forms.Form):
email = forms.EmailField(label=_('Email'), max_length=254, required=True)
class _meta:
model = User
class UserProfileForm(ModelForm):
"""Allows specifying FK querysets through kwarg"""
def __init__(self, querysets=None, *args, **kwargs):
querysets = querysets or {}
super().__init__(*args, **kwargs)
for field, qs in querysets.items():
self.fields[field].queryset = qs
class UserChangeForm(BaseUserChangeForm):
"""Add custom cleaning to UserChangeForm"""
def __init__(self, *args, **kwargs):
self.request = kwargs.pop("request") # Inject current request into form object
super().__init__(*args, **kwargs)
def clean(self):
cleaned_data = super().clean()
if not self.request.user.is_superuser:
if self.instance:
current_restricted = set(
self.instance.groups.filter(
authgroup__restricted=True
).values_list("pk", flat=True)
)
else:
current_restricted = set()
new_restricted = set(
cleaned_data["groups"].filter(
authgroup__restricted=True
).values_list("pk", flat=True)
)
if current_restricted != new_restricted:
restricted_removed = current_restricted - new_restricted
restricted_added = new_restricted - current_restricted
restricted_changed = restricted_removed | restricted_added
restricted_names_qs = Group.objects.filter(
pk__in=restricted_changed
).values_list("name", flat=True)
restricted_names = ",".join(list(restricted_names_qs))
raise ValidationError(
{
"groups": _(
"You are not allowed to add or remove these "
"restricted groups: %s" % restricted_names
)
}
)

View File

@@ -0,0 +1,40 @@
from collections import namedtuple
import datetime as dt
from .event_series import EventSeries
"""Global series for counting task events."""
succeeded_tasks = EventSeries("SUCCEEDED_TASKS")
retried_tasks = EventSeries("RETRIED_TASKS")
failed_tasks = EventSeries("FAILED_TASKS")
_TaskCounts = namedtuple(
"_TaskCounts", ["succeeded", "retried", "failed", "total", "earliest_task", "hours"]
)
def dashboard_results(hours: int) -> _TaskCounts:
"""Counts of all task events within the given timeframe."""
def earliest_if_exists(events: EventSeries, earliest: dt.datetime) -> list:
my_earliest = events.first_event(earliest=earliest)
return [my_earliest] if my_earliest else []
earliest = dt.datetime.utcnow() - dt.timedelta(hours=hours)
earliest_events = list()
succeeded_count = succeeded_tasks.count(earliest=earliest)
earliest_events += earliest_if_exists(succeeded_tasks, earliest)
retried_count = retried_tasks.count(earliest=earliest)
earliest_events += earliest_if_exists(retried_tasks, earliest)
failed_count = failed_tasks.count(earliest=earliest)
earliest_events += earliest_if_exists(failed_tasks, earliest)
return _TaskCounts(
succeeded=succeeded_count,
retried=retried_count,
failed=failed_count,
total=succeeded_count + retried_count + failed_count,
earliest_task=min(earliest_events) if earliest_events else None,
hours=hours,
)

View File

@@ -0,0 +1,130 @@
import datetime as dt
import logging
from typing import List, Optional
from pytz import utc
from redis import Redis, RedisError
from allianceauth.utils.cache import get_redis_client
logger = logging.getLogger(__name__)
class _RedisStub:
"""Stub of a Redis client.
It's purpose is to prevent EventSeries objects from trying to access Redis
when it is not available. e.g. when the Sphinx docs are rendered by readthedocs.org.
"""
def delete(self, *args, **kwargs):
pass
def incr(self, *args, **kwargs):
return 0
def zadd(self, *args, **kwargs):
pass
def zcount(self, *args, **kwargs):
pass
def zrangebyscore(self, *args, **kwargs):
pass
class EventSeries:
"""API for recording and analyzing a series of events."""
_ROOT_KEY = "ALLIANCEAUTH_EVENT_SERIES"
def __init__(self, key_id: str, redis: Redis = None) -> None:
self._redis = get_redis_client() if not redis else redis
try:
if not self._redis.ping():
raise RuntimeError()
except (AttributeError, RedisError, RuntimeError):
logger.exception(
"Failed to establish a connection with Redis. "
"This EventSeries object is disabled.",
)
self._redis = _RedisStub()
self._key_id = str(key_id)
self.clear()
@property
def is_disabled(self):
"""True when this object is disabled, e.g. Redis was not available at startup."""
return isinstance(self._redis, _RedisStub)
@property
def _key_counter(self):
return f"{self._ROOT_KEY}_{self._key_id}_COUNTER"
@property
def _key_sorted_set(self):
return f"{self._ROOT_KEY}_{self._key_id}_SORTED_SET"
def add(self, event_time: dt.datetime = None) -> None:
"""Add event.
Args:
- event_time: timestamp of event. Will use current time if not specified.
"""
if not event_time:
event_time = dt.datetime.utcnow()
id = self._redis.incr(self._key_counter)
self._redis.zadd(self._key_sorted_set, {id: event_time.timestamp()})
def all(self) -> List[dt.datetime]:
"""List of all known events."""
return [
event[1]
for event in self._redis.zrangebyscore(
self._key_sorted_set,
"-inf",
"+inf",
withscores=True,
score_cast_func=self._cast_scores_to_dt,
)
]
def clear(self) -> None:
"""Clear all events."""
self._redis.delete(self._key_sorted_set)
self._redis.delete(self._key_counter)
def count(self, earliest: dt.datetime = None, latest: dt.datetime = None) -> int:
"""Count of events, can be restricted to given timeframe.
Args:
- earliest: Date of first events to count(inclusive), or -infinite if not specified
- latest: Date of last events to count(inclusive), or +infinite if not specified
"""
min = "-inf" if not earliest else earliest.timestamp()
max = "+inf" if not latest else latest.timestamp()
return self._redis.zcount(self._key_sorted_set, min=min, max=max)
def first_event(self, earliest: dt.datetime = None) -> Optional[dt.datetime]:
"""Date/Time of first event. Returns `None` if series has no events.
Args:
- earliest: Date of first events to count(inclusive), or any if not specified
"""
min = "-inf" if not earliest else earliest.timestamp()
event = self._redis.zrangebyscore(
self._key_sorted_set,
min,
"+inf",
withscores=True,
start=0,
num=1,
score_cast_func=self._cast_scores_to_dt,
)
if not event:
return None
return event[0][1]
@staticmethod
def _cast_scores_to_dt(score) -> dt.datetime:
return dt.datetime.fromtimestamp(float(score), tz=utc)

View File

@@ -0,0 +1,54 @@
from celery.signals import (
task_failure,
task_internal_error,
task_retry,
task_success,
worker_ready
)
from django.conf import settings
from .counters import failed_tasks, retried_tasks, succeeded_tasks
def reset_counters():
"""Reset all counters for the celery status."""
succeeded_tasks.clear()
failed_tasks.clear()
retried_tasks.clear()
def is_enabled() -> bool:
return not bool(
getattr(settings, "ALLIANCEAUTH_DASHBOARD_TASK_STATISTICS_DISABLED", False)
)
@worker_ready.connect
def reset_counters_when_celery_restarted(*args, **kwargs):
if is_enabled():
reset_counters()
@task_success.connect
def record_task_succeeded(*args, **kwargs):
if is_enabled():
succeeded_tasks.add()
@task_retry.connect
def record_task_retried(*args, **kwargs):
if is_enabled():
retried_tasks.add()
@task_failure.connect
def record_task_failed(*args, **kwargs):
if is_enabled():
failed_tasks.add()
@task_internal_error.connect
def record_task_internal_error(*args, **kwargs):
if is_enabled():
failed_tasks.add()

View File

@@ -0,0 +1,51 @@
import datetime as dt
from django.test import TestCase
from django.utils.timezone import now
from allianceauth.authentication.task_statistics.counters import (
dashboard_results,
succeeded_tasks,
retried_tasks,
failed_tasks,
)
class TestDashboardResults(TestCase):
def test_should_return_counts_for_given_timeframe_only(self):
# given
earliest_task = now() - dt.timedelta(minutes=15)
succeeded_tasks.clear()
succeeded_tasks.add(now() - dt.timedelta(hours=1, seconds=1))
succeeded_tasks.add(earliest_task)
succeeded_tasks.add()
succeeded_tasks.add()
retried_tasks.clear()
retried_tasks.add(now() - dt.timedelta(hours=1, seconds=1))
retried_tasks.add(now() - dt.timedelta(seconds=30))
retried_tasks.add()
failed_tasks.clear()
failed_tasks.add(now() - dt.timedelta(hours=1, seconds=1))
failed_tasks.add()
# when
results = dashboard_results(hours=1)
# then
self.assertEqual(results.succeeded, 3)
self.assertEqual(results.retried, 2)
self.assertEqual(results.failed, 1)
self.assertEqual(results.total, 6)
self.assertEqual(results.earliest_task, earliest_task)
def test_should_work_with_no_data(self):
# given
succeeded_tasks.clear()
retried_tasks.clear()
failed_tasks.clear()
# when
results = dashboard_results(hours=1)
# then
self.assertEqual(results.succeeded, 0)
self.assertEqual(results.retried, 0)
self.assertEqual(results.failed, 0)
self.assertEqual(results.total, 0)
self.assertIsNone(results.earliest_task)

View File

@@ -0,0 +1,168 @@
import datetime as dt
from unittest.mock import patch
from pytz import utc
from redis import RedisError
from django.test import TestCase
from django.utils.timezone import now
from allianceauth.authentication.task_statistics.event_series import (
EventSeries,
_RedisStub,
)
MODULE_PATH = "allianceauth.authentication.task_statistics.event_series"
class TestEventSeries(TestCase):
def test_should_abort_without_redis_client(self):
# when
with patch(MODULE_PATH + ".get_redis_client") as mock:
mock.return_value = None
events = EventSeries("dummy")
# then
self.assertTrue(events._redis, _RedisStub)
self.assertTrue(events.is_disabled)
def test_should_disable_itself_if_redis_not_available_1(self):
# when
with patch(MODULE_PATH + ".get_redis_client") as mock_get_master_client:
mock_get_master_client.return_value.ping.side_effect = RedisError
events = EventSeries("dummy")
# then
self.assertIsInstance(events._redis, _RedisStub)
self.assertTrue(events.is_disabled)
def test_should_disable_itself_if_redis_not_available_2(self):
# when
with patch(MODULE_PATH + ".get_redis_client") as mock_get_master_client:
mock_get_master_client.return_value.ping.return_value = False
events = EventSeries("dummy")
# then
self.assertIsInstance(events._redis, _RedisStub)
self.assertTrue(events.is_disabled)
def test_should_add_event(self):
# given
events = EventSeries("dummy")
# when
events.add()
# then
result = events.all()
self.assertEqual(len(result), 1)
self.assertAlmostEqual(result[0], now(), delta=dt.timedelta(seconds=30))
def test_should_add_event_with_specified_time(self):
# given
events = EventSeries("dummy")
my_time = dt.datetime(2021, 11, 1, 12, 15, tzinfo=utc)
# when
events.add(my_time)
# then
result = events.all()
self.assertEqual(len(result), 1)
self.assertAlmostEqual(result[0], my_time, delta=dt.timedelta(seconds=30))
def test_should_count_events(self):
# given
events = EventSeries("dummy")
events.add()
events.add()
# when
result = events.count()
# then
self.assertEqual(result, 2)
def test_should_count_zero(self):
# given
events = EventSeries("dummy")
# when
result = events.count()
# then
self.assertEqual(result, 0)
def test_should_count_events_within_timeframe_1(self):
# given
events = EventSeries("dummy")
events.add(dt.datetime(2021, 12, 1, 12, 0, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 10, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 15, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 30, tzinfo=utc))
# when
result = events.count(
earliest=dt.datetime(2021, 12, 1, 12, 8, tzinfo=utc),
latest=dt.datetime(2021, 12, 1, 12, 17, tzinfo=utc),
)
# then
self.assertEqual(result, 2)
def test_should_count_events_within_timeframe_2(self):
# given
events = EventSeries("dummy")
events.add(dt.datetime(2021, 12, 1, 12, 0, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 10, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 15, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 30, tzinfo=utc))
# when
result = events.count(earliest=dt.datetime(2021, 12, 1, 12, 8))
# then
self.assertEqual(result, 3)
def test_should_count_events_within_timeframe_3(self):
# given
events = EventSeries("dummy")
events.add(dt.datetime(2021, 12, 1, 12, 0, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 10, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 15, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 30, tzinfo=utc))
# when
result = events.count(latest=dt.datetime(2021, 12, 1, 12, 12))
# then
self.assertEqual(result, 2)
def test_should_clear_events(self):
# given
events = EventSeries("dummy")
events.add()
events.add()
# when
events.clear()
# then
self.assertEqual(events.count(), 0)
def test_should_return_date_of_first_event(self):
# given
events = EventSeries("dummy")
events.add(dt.datetime(2021, 12, 1, 12, 0, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 10, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 15, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 30, tzinfo=utc))
# when
result = events.first_event()
# then
self.assertEqual(result, dt.datetime(2021, 12, 1, 12, 0, tzinfo=utc))
def test_should_return_date_of_first_event_with_range(self):
# given
events = EventSeries("dummy")
events.add(dt.datetime(2021, 12, 1, 12, 0, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 10, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 15, tzinfo=utc))
events.add(dt.datetime(2021, 12, 1, 12, 30, tzinfo=utc))
# when
result = events.first_event(
earliest=dt.datetime(2021, 12, 1, 12, 8, tzinfo=utc)
)
# then
self.assertEqual(result, dt.datetime(2021, 12, 1, 12, 10, tzinfo=utc))
def test_should_return_all_events(self):
# given
events = EventSeries("dummy")
events.add()
events.add()
# when
results = events.all()
# then
self.assertEqual(len(results), 2)

View File

@@ -0,0 +1,93 @@
from unittest.mock import patch
from celery.exceptions import Retry
from django.test import TestCase, override_settings
from allianceauth.authentication.task_statistics.counters import (
failed_tasks,
retried_tasks,
succeeded_tasks,
)
from allianceauth.authentication.task_statistics.signals import (
reset_counters,
is_enabled,
)
from allianceauth.eveonline.tasks import update_character
@override_settings(
CELERY_ALWAYS_EAGER=True,ALLIANCEAUTH_DASHBOARD_TASK_STATISTICS_DISABLED=False
)
class TestTaskSignals(TestCase):
fixtures = ["disable_analytics"]
def test_should_record_successful_task(self):
# given
succeeded_tasks.clear()
retried_tasks.clear()
failed_tasks.clear()
# when
with patch(
"allianceauth.eveonline.tasks.EveCharacter.objects.update_character"
) as mock_update:
mock_update.return_value = None
update_character.delay(1)
# then
self.assertEqual(succeeded_tasks.count(), 1)
self.assertEqual(retried_tasks.count(), 0)
self.assertEqual(failed_tasks.count(), 0)
def test_should_record_retried_task(self):
# given
succeeded_tasks.clear()
retried_tasks.clear()
failed_tasks.clear()
# when
with patch(
"allianceauth.eveonline.tasks.EveCharacter.objects.update_character"
) as mock_update:
mock_update.side_effect = Retry
update_character.delay(1)
# then
self.assertEqual(succeeded_tasks.count(), 0)
self.assertEqual(failed_tasks.count(), 0)
self.assertEqual(retried_tasks.count(), 1)
def test_should_record_failed_task(self):
# given
succeeded_tasks.clear()
retried_tasks.clear()
failed_tasks.clear()
# when
with patch(
"allianceauth.eveonline.tasks.EveCharacter.objects.update_character"
) as mock_update:
mock_update.side_effect = RuntimeError
update_character.delay(1)
# then
self.assertEqual(succeeded_tasks.count(), 0)
self.assertEqual(retried_tasks.count(), 0)
self.assertEqual(failed_tasks.count(), 1)
def test_should_reset_counters(self):
# given
succeeded_tasks.add()
retried_tasks.add()
failed_tasks.add()
# when
reset_counters()
# then
self.assertEqual(succeeded_tasks.count(), 0)
self.assertEqual(retried_tasks.count(), 0)
self.assertEqual(failed_tasks.count(), 0)
class TestIsEnabled(TestCase):
@override_settings(ALLIANCEAUTH_DASHBOARD_TASK_STATISTICS_DISABLED=False)
def test_enabled(self):
self.assertTrue(is_enabled())
@override_settings(ALLIANCEAUTH_DASHBOARD_TASK_STATISTICS_DISABLED=True)
def test_disabled(self):
self.assertFalse(is_enabled())

View File

@@ -1,6 +1,9 @@
from bs4 import BeautifulSoup
from urllib.parse import quote
from unittest.mock import patch, MagicMock
from django_webtest import WebTest
from django.contrib.admin.sites import AdminSite
from django.contrib.auth.models import Group
from django.test import TestCase, RequestFactory, Client
@@ -188,7 +191,7 @@ class TestCaseWithTestData(TestCase):
corporation_id=5432,
corporation_name="Xavier's School for Gifted Youngsters",
corporation_ticker='MUTNT',
alliance_id = None,
alliance_id=None,
faction_id=999,
faction_name='The X-Men',
)
@@ -206,6 +209,7 @@ class TestCaseWithTestData(TestCase):
cls.user_4.profile.save()
EveFactionInfo.objects.create(faction_id=999, faction_name='The X-Men')
def make_generic_search_request(ModelClass: type, search_term: str):
User.objects.create_superuser(
username='superuser', password='secret', email='admin@example.com'
@@ -218,6 +222,7 @@ def make_generic_search_request(ModelClass: type, search_term: str):
class TestCharacterOwnershipAdmin(TestCaseWithTestData):
fixtures = ["disable_analytics"]
def setUp(self):
self.modeladmin = CharacterOwnershipAdmin(
@@ -244,6 +249,7 @@ class TestCharacterOwnershipAdmin(TestCaseWithTestData):
class TestOwnershipRecordAdmin(TestCaseWithTestData):
fixtures = ["disable_analytics"]
def setUp(self):
self.modeladmin = OwnershipRecordAdmin(
@@ -270,11 +276,12 @@ class TestOwnershipRecordAdmin(TestCaseWithTestData):
class TestStateAdmin(TestCaseWithTestData):
fixtures = ["disable_analytics"]
def setUp(self):
self.modeladmin = StateAdmin(
model=User, admin_site=AdminSite()
)
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.modeladmin = StateAdmin(model=User, admin_site=AdminSite())
def test_change_view_loads_normally(self):
User.objects.create_superuser(
@@ -299,6 +306,7 @@ class TestStateAdmin(TestCaseWithTestData):
class TestUserAdmin(TestCaseWithTestData):
fixtures = ["disable_analytics"]
def setUp(self):
self.factory = RequestFactory()
@@ -344,7 +352,7 @@ class TestUserAdmin(TestCaseWithTestData):
self.assertEqual(user_main_organization(self.user_3), expected)
def test_user_main_organization_u4(self):
expected="Xavier's School for Gifted Youngsters<br>The X-Men"
expected = "Xavier's School for Gifted Youngsters<br>The X-Men"
self.assertEqual(user_main_organization(self.user_4), expected)
def test_characters_u1(self):
@@ -537,6 +545,229 @@ class TestUserAdmin(TestCaseWithTestData):
self.assertEqual(response.status_code, expected)
class TestStateAdminChangeFormSuperuserExclusiveEdits(WebTest):
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.super_admin = User.objects.create_superuser("super_admin")
cls.staff_admin = User.objects.create_user("staff_admin")
cls.staff_admin.is_staff = True
cls.staff_admin.save()
cls.staff_admin = AuthUtils.add_permissions_to_user_by_name(
[
"authentication.add_state",
"authentication.change_state",
"authentication.view_state",
],
cls.staff_admin
)
cls.superuser_exclusive_fields = ["permissions",]
def test_should_show_all_fields_to_superuser_for_add(self):
# given
self.app.set_user(self.super_admin)
page = self.app.get("/admin/authentication/state/add/")
# when
form = page.forms["state_form"]
# then
for field in self.superuser_exclusive_fields:
with self.subTest(field=field):
self.assertIn(field, form.fields)
def test_should_not_show_all_fields_to_staff_admins_for_add(self):
# given
self.app.set_user(self.staff_admin)
page = self.app.get("/admin/authentication/state/add/")
# when
form = page.forms["state_form"]
# then
for field in self.superuser_exclusive_fields:
with self.subTest(field=field):
self.assertNotIn(field, form.fields)
def test_should_show_all_fields_to_superuser_for_change(self):
# given
self.app.set_user(self.super_admin)
state = AuthUtils.get_member_state()
page = self.app.get(f"/admin/authentication/state/{state.pk}/change/")
# when
form = page.forms["state_form"]
# then
for field in self.superuser_exclusive_fields:
with self.subTest(field=field):
self.assertIn(field, form.fields)
def test_should_not_show_all_fields_to_staff_admin_for_change(self):
# given
self.app.set_user(self.staff_admin)
state = AuthUtils.get_member_state()
page = self.app.get(f"/admin/authentication/state/{state.pk}/change/")
# when
form = page.forms["state_form"]
# then
for field in self.superuser_exclusive_fields:
with self.subTest(field=field):
self.assertNotIn(field, form.fields)
class TestUserAdminChangeForm(TestCase):
fixtures = ["disable_analytics"]
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.modeladmin = UserAdmin(model=User, admin_site=AdminSite())
def test_should_show_groups_available_to_user_with_blue_state_only(self):
# given
superuser = User.objects.create_superuser("Super")
user = AuthUtils.create_user("bruce_wayne")
character = AuthUtils.add_main_character_2(
user,
name="Bruce Wayne",
character_id=1001,
corp_id=2001,
corp_name="Wayne Technologies"
)
blue_state = State.objects.get(name="Blue")
blue_state.member_characters.add(character)
member_state = AuthUtils.get_member_state()
group_1 = Group.objects.create(name="Group 1")
group_2 = Group.objects.create(name="Group 2")
group_2.authgroup.states.add(blue_state)
group_3 = Group.objects.create(name="Group 3")
group_3.authgroup.states.add(member_state)
self.client.force_login(superuser)
# when
response = self.client.get(f"/admin/authentication/user/{user.pk}/change/")
# then
self.assertEqual(response.status_code, 200)
soup = BeautifulSoup(response.rendered_content, features="html.parser")
groups_select = soup.find("select", {"id": "id_groups"}).find_all('option')
group_ids = {int(option["value"]) for option in groups_select}
self.assertSetEqual(group_ids, {group_1.pk, group_2.pk})
class TestUserAdminChangeFormSuperuserExclusiveEdits(WebTest):
fixtures = ["disable_analytics"]
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.super_admin = User.objects.create_superuser("super_admin")
cls.staff_admin = User.objects.create_user("staff_admin")
cls.staff_admin.is_staff = True
cls.staff_admin.save()
cls.staff_admin = AuthUtils.add_permissions_to_user_by_name(
[
"auth.change_user",
"auth.view_user",
"authentication.change_user",
"authentication.change_userprofile",
"authentication.view_user"
],
cls.staff_admin
)
cls.superuser_exclusive_fields = [
"is_staff", "is_superuser", "user_permissions"
]
def setUp(self) -> None:
self.user = AuthUtils.create_user("bruce_wayne")
def test_should_show_all_fields_to_superuser_for_change(self):
# given
self.app.set_user(self.super_admin)
page = self.app.get(f"/admin/authentication/user/{self.user.pk}/change/")
# when
form = page.forms["user_form"]
# then
for field in self.superuser_exclusive_fields:
with self.subTest(field=field):
self.assertIn(field, form.fields)
def test_should_not_show_all_fields_to_staff_admin_for_change(self):
# given
self.app.set_user(self.staff_admin)
page = self.app.get(f"/admin/authentication/user/{self.user.pk}/change/")
# when
form = page.forms["user_form"]
# then
for field in self.superuser_exclusive_fields:
with self.subTest(field=field):
self.assertNotIn(field, form.fields)
def test_should_allow_super_admin_to_add_restricted_group_to_user(self):
# given
self.app.set_user(self.super_admin)
group_restricted = Group.objects.create(name="restricted group")
group_restricted.authgroup.restricted = True
group_restricted.authgroup.save()
page = self.app.get(f"/admin/authentication/user/{self.user.pk}/change/")
form = page.forms["user_form"]
# when
form["groups"].select_multiple(texts=["restricted group"])
response = form.submit("save")
# then
self.assertEqual(response.status_code, 302)
self.user.refresh_from_db()
self.assertIn(
"restricted group", self.user.groups.values_list("name", flat=True)
)
def test_should_not_allow_staff_admin_to_add_restricted_group_to_user(self):
# given
self.app.set_user(self.staff_admin)
group_restricted = Group.objects.create(name="restricted group")
group_restricted.authgroup.restricted = True
group_restricted.authgroup.save()
page = self.app.get(f"/admin/authentication/user/{self.user.pk}/change/")
form = page.forms["user_form"]
# when
form["groups"].select_multiple(texts=["restricted group"])
response = form.submit("save")
# then
self.assertEqual(response.status_code, 200)
self.assertIn(
"You are not allowed to add or remove these restricted groups",
response.text
)
def test_should_not_allow_staff_admin_to_remove_restricted_group_from_user(self):
# given
self.app.set_user(self.staff_admin)
group_restricted = Group.objects.create(name="restricted group")
group_restricted.authgroup.restricted = True
group_restricted.authgroup.save()
self.user.groups.add(group_restricted)
page = self.app.get(f"/admin/authentication/user/{self.user.pk}/change/")
form = page.forms["user_form"]
# when
form["groups"].select_multiple(texts=[])
response = form.submit("save")
# then
self.assertEqual(response.status_code, 200)
self.assertIn(
"You are not allowed to add or remove these restricted groups",
response.text
)
def test_should_allow_staff_admin_to_add_normal_group_to_user(self):
# given
self.app.set_user(self.super_admin)
Group.objects.create(name="normal group")
page = self.app.get(f"/admin/authentication/user/{self.user.pk}/change/")
form = page.forms["user_form"]
# when
form["groups"].select_multiple(texts=["normal group"])
response = form.submit("save")
# then
self.assertEqual(response.status_code, 302)
self.user.refresh_from_db()
self.assertIn("normal group", self.user.groups.values_list("name", flat=True))
class TestMakeServicesHooksActions(TestCaseWithTestData):
class MyServicesHookTypeA(ServicesHook):

View File

@@ -55,7 +55,6 @@ TEST_VERSION = '2.6.5'
class TestStatusOverviewTag(TestCase):
@patch(MODULE_PATH + '.admin_status.__version__', TEST_VERSION)
@patch(MODULE_PATH + '.admin_status._fetch_celery_queue_length')
@patch(MODULE_PATH + '.admin_status._current_version_summary')
@@ -66,6 +65,7 @@ class TestStatusOverviewTag(TestCase):
mock_current_version_info,
mock_fetch_celery_queue_length
):
# given
notifications = {
'notifications': GITHUB_NOTIFICATION_ISSUES[:5]
}
@@ -83,22 +83,20 @@ class TestStatusOverviewTag(TestCase):
}
mock_current_version_info.return_value = version_info
mock_fetch_celery_queue_length.return_value = 3
# when
result = status_overview()
expected = {
'notifications': GITHUB_NOTIFICATION_ISSUES[:5],
'latest_major': True,
'latest_minor': True,
'latest_patch': True,
'latest_beta': False,
'current_version': TEST_VERSION,
'latest_major_version': '2.4.5',
'latest_minor_version': '2.4.0',
'latest_patch_version': '2.4.5',
'latest_beta_version': '2.4.4a1',
'task_queue_length': 3,
}
self.assertEqual(result, expected)
# then
self.assertEqual(result["notifications"], GITHUB_NOTIFICATION_ISSUES[:5])
self.assertTrue(result["latest_major"])
self.assertTrue(result["latest_minor"])
self.assertTrue(result["latest_patch"])
self.assertFalse(result["latest_beta"])
self.assertEqual(result["current_version"], TEST_VERSION)
self.assertEqual(result["latest_major_version"], '2.4.5')
self.assertEqual(result["latest_minor_version"], '2.4.0')
self.assertEqual(result["latest_patch_version"], '2.4.5')
self.assertEqual(result["latest_beta_version"], '2.4.4a1')
self.assertEqual(result["task_queue_length"], 3)
class TestNotifications(TestCase):

View File

@@ -193,6 +193,8 @@
"columnDefs": [
{ "sortable": false, "targets": [1] },
],
"stateSave": true,
"stateDuration": 0
});
$('#table-members').DataTable({
"columnDefs": [
@@ -200,6 +202,8 @@
{ "sortable": false, "targets": [0, 2] },
],
"order": [[ 1, "asc" ]],
"stateSave": true,
"stateDuration": 0
});
$('#table-unregistered').DataTable({
"columnDefs": [
@@ -207,6 +211,8 @@
{ "sortable": false, "targets": [0, 2] },
],
"order": [[ 1, "asc" ]],
"stateSave": true,
"stateDuration": 0
});
});

View File

@@ -43,6 +43,9 @@
{% endblock %}
{% block extra_script %}
$(document).ready(function(){
$('#table-search').DataTable();
$('#table-search').DataTable({
"stateSave": true,
"stateDuration": 0
});
});
{% endblock %}

View File

@@ -1,13 +1,27 @@
from django.db import models
import logging
from typing import Union
from .managers import EveCharacterManager, EveCharacterProviderManager
from .managers import EveCorporationManager, EveCorporationProviderManager
from .managers import EveAllianceManager, EveAllianceProviderManager
from django.core.exceptions import ObjectDoesNotExist
from django.db import models
from esi.models import Token
from allianceauth.notifications import notify
from . import providers
from .evelinks import eveimageserver
from .managers import (
EveAllianceManager,
EveAllianceProviderManager,
EveCharacterManager,
EveCharacterProviderManager,
EveCorporationManager,
EveCorporationProviderManager,
)
logger = logging.getLogger(__name__)
_DEFAULT_IMAGE_SIZE = 32
DOOMHEIM_CORPORATION_ID = 1000001
class EveFactionInfo(models.Model):
@@ -68,13 +82,12 @@ class EveAllianceInfo(models.Model):
for corp_id in alliance.corp_ids:
if not EveCorporationInfo.objects.filter(corporation_id=corp_id).exists():
EveCorporationInfo.objects.create_corporation(corp_id)
EveCorporationInfo.objects.filter(
corporation_id__in=alliance.corp_ids).update(alliance=self
EveCorporationInfo.objects.filter(corporation_id__in=alliance.corp_ids).update(
alliance=self
)
EveCorporationInfo.objects\
.filter(alliance=self)\
.exclude(corporation_id__in=alliance.corp_ids)\
.update(alliance=None)
EveCorporationInfo.objects.filter(alliance=self).exclude(
corporation_id__in=alliance.corp_ids
).update(alliance=None)
def update_alliance(self, alliance: providers.Alliance = None):
if alliance is None:
@@ -182,6 +195,7 @@ class EveCorporationInfo(models.Model):
class EveCharacter(models.Model):
"""Character in Eve Online"""
character_id = models.PositiveIntegerField(unique=True)
character_name = models.CharField(max_length=254, unique=True)
corporation_id = models.PositiveIntegerField()
@@ -198,12 +212,20 @@ class EveCharacter(models.Model):
class Meta:
indexes = [
models.Index(fields=['corporation_id',]),
models.Index(fields=['alliance_id',]),
models.Index(fields=['corporation_name',]),
models.Index(fields=['alliance_name',]),
models.Index(fields=['faction_id',]),
]
models.Index(fields=['corporation_id',]),
models.Index(fields=['alliance_id',]),
models.Index(fields=['corporation_name',]),
models.Index(fields=['alliance_name',]),
models.Index(fields=['faction_id',]),
]
def __str__(self):
return self.character_name
@property
def is_biomassed(self) -> bool:
"""Whether this character is dead or not."""
return self.corporation_id == DOOMHEIM_CORPORATION_ID
@property
def alliance(self) -> Union[EveAllianceInfo, None]:
@@ -249,10 +271,36 @@ class EveCharacter(models.Model):
self.faction_id = character.faction.id
self.faction_name = character.faction.name
self.save()
if self.is_biomassed:
self._remove_tokens_of_biomassed_character()
return self
def __str__(self):
return self.character_name
def _remove_tokens_of_biomassed_character(self) -> None:
"""Remove tokens of this biomassed character."""
try:
user = self.character_ownership.user
except ObjectDoesNotExist:
return
tokens_to_delete = Token.objects.filter(character_id=self.character_id)
tokens_count = tokens_to_delete.count()
if not tokens_count:
return
tokens_to_delete.delete()
logger.info(
"%d tokens from user %s for biomassed character %s [id:%s] deleted.",
tokens_count,
user,
self,
self.character_id,
)
notify(
user=user,
title=f"Character {self} biomassed",
message=(
f"Your former character {self} has been biomassed "
"and has been removed from the list of your alts."
)
)
@staticmethod
def generic_portrait_url(
@@ -336,7 +384,6 @@ class EveCharacter(models.Model):
"""image URL for alliance of this character or empty string"""
return self.alliance_logo_url(256)
def faction_logo_url(self, size=_DEFAULT_IMAGE_SIZE) -> str:
"""image URL for alliance of this character or empty string"""
if self.faction_id:

View File

@@ -170,7 +170,7 @@ class EveProvider:
"""
:return: an ItemType object for the given ID
"""
raise NotImplemented()
raise NotImplementedError()
class EveSwaggerProvider(EveProvider):
@@ -207,7 +207,8 @@ class EveSwaggerProvider(EveProvider):
def __str__(self):
return 'esi'
def get_alliance(self, alliance_id):
def get_alliance(self, alliance_id: int) -> Alliance:
"""Fetch alliance from ESI."""
try:
data = self.client.Alliance.get_alliances_alliance_id(alliance_id=alliance_id).result()
corps = self.client.Alliance.get_alliances_alliance_id_corporations(alliance_id=alliance_id).result()
@@ -223,7 +224,8 @@ class EveSwaggerProvider(EveProvider):
except HTTPNotFound:
raise ObjectNotFound(alliance_id, 'alliance')
def get_corp(self, corp_id):
def get_corp(self, corp_id: int) -> Corporation:
"""Fetch corporation from ESI."""
try:
data = self.client.Corporation.get_corporations_corporation_id(corporation_id=corp_id).result()
model = Corporation(
@@ -239,29 +241,43 @@ class EveSwaggerProvider(EveProvider):
except HTTPNotFound:
raise ObjectNotFound(corp_id, 'corporation')
def get_character(self, character_id):
def get_character(self, character_id: int) -> Character:
"""Fetch character from ESI."""
try:
data = self.client.Character.get_characters_character_id(character_id=character_id).result()
character_name = self._fetch_character_name(character_id)
affiliation = self.client.Character.post_characters_affiliation(characters=[character_id]).result()[0]
model = Character(
id=character_id,
name=data['name'],
name=character_name,
corp_id=affiliation['corporation_id'],
alliance_id=affiliation['alliance_id'] if 'alliance_id' in affiliation else None,
faction_id=affiliation['faction_id'] if 'faction_id' in affiliation else None,
)
return model
except (HTTPNotFound, HTTPUnprocessableEntity):
except (HTTPNotFound, HTTPUnprocessableEntity, ObjectNotFound):
raise ObjectNotFound(character_id, 'character')
def _fetch_character_name(self, character_id: int) -> str:
"""Fetch character name from ESI."""
data = self.client.Universe.post_universe_names(ids=[character_id]).result()
character = data.pop() if data else None
if (
not character
or character["category"] != "character"
or character["id"] != character_id
):
raise ObjectNotFound(character_id, 'character')
return character["name"]
def get_all_factions(self):
"""Fetch all factions from ESI."""
if not self._faction_list:
self._faction_list = self.client.Universe.get_universe_factions().result()
return self._faction_list
def get_faction(self, faction_id):
faction_id=int(faction_id)
def get_faction(self, faction_id: int):
"""Fetch faction from ESI."""
faction_id = int(faction_id)
try:
if not self._faction_list:
_ = self.get_all_factions()
@@ -273,7 +289,8 @@ class EveSwaggerProvider(EveProvider):
except (HTTPNotFound, HTTPUnprocessableEntity, KeyError):
raise ObjectNotFound(faction_id, 'faction')
def get_itemtype(self, type_id):
def get_itemtype(self, type_id: int) -> ItemType:
"""Fetch inventory item from ESI."""
try:
data = self.client.Universe.get_universe_types_type_id(type_id=type_id).result()
return ItemType(id=type_id, name=data['name'])

View File

@@ -1,12 +1,11 @@
import logging
from celery import shared_task
from .models import EveAllianceInfo
from .models import EveCharacter
from .models import EveCorporationInfo
from .models import EveAllianceInfo, EveCharacter, EveCorporationInfo
from . import providers
logger = logging.getLogger(__name__)
TASK_PRIORITY = 7
@@ -32,8 +31,8 @@ def update_alliance(alliance_id):
@shared_task
def update_character(character_id):
"""Update given character from ESI"""
def update_character(character_id: int) -> None:
"""Update given character from ESI."""
EveCharacter.objects.update_character(character_id)
@@ -65,17 +64,17 @@ def update_character_chunk(character_ids_chunk: list):
.post_characters_affiliation(characters=character_ids_chunk).result()
character_names = providers.provider.client.Universe\
.post_universe_names(ids=character_ids_chunk).result()
except:
except OSError:
logger.info("Failed to bulk update characters. Attempting single updates")
for character_id in character_ids_chunk:
update_character.apply_async(
args=[character_id], priority=TASK_PRIORITY
)
args=[character_id], priority=TASK_PRIORITY
)
return
affiliations = {
affiliation.get('character_id'): affiliation
for affiliation in affiliations_raw
affiliation.get('character_id'): affiliation
for affiliation in affiliations_raw
}
# add character names to affiliations
for character in character_names:
@@ -108,5 +107,5 @@ def update_character_chunk(character_ids_chunk: list):
if corp_changed or alliance_changed or name_changed:
update_character.apply_async(
args=[character.get('character_id')], priority=TASK_PRIORITY
)
args=[character.get('character_id')], priority=TASK_PRIORITY
)

View File

@@ -0,0 +1,168 @@
from bravado.exception import HTTPNotFound
class BravadoResponseStub:
"""Stub for IncomingResponse in bravado, e.g. for HTTPError exceptions"""
def __init__(
self, status_code, reason="", text="", headers=None, raw_bytes=None
) -> None:
self.reason = reason
self.status_code = status_code
self.text = text
self.headers = headers if headers else dict()
self.raw_bytes = raw_bytes
def __str__(self):
return f"{self.status_code} {self.reason}"
class BravadoOperationStub:
"""Stub to simulate the operation object return from bravado via django-esi"""
class RequestConfig:
def __init__(self, also_return_response):
self.also_return_response = also_return_response
class ResponseStub:
def __init__(self, headers):
self.headers = headers
def __init__(self, data, headers: dict = None, also_return_response: bool = False):
self._data = data
self._headers = headers if headers else {"x-pages": 1}
self.request_config = BravadoOperationStub.RequestConfig(also_return_response)
def result(self, **kwargs):
if self.request_config.also_return_response:
return [self._data, self.ResponseStub(self._headers)]
else:
return self._data
def results(self, **kwargs):
return self.result(**kwargs)
class EsiClientStub:
"""Stub for an ESI client."""
class Alliance:
@staticmethod
def get_alliances_alliance_id(alliance_id):
data = {
3001: {
"name": "Wayne Enterprises",
"ticker": "WYE",
"executor_corporation_id": 2001
}
}
try:
return BravadoOperationStub(data[int(alliance_id)])
except KeyError:
response = BravadoResponseStub(
404, f"Alliance with ID {alliance_id} not found"
)
raise HTTPNotFound(response)
@staticmethod
def get_alliances_alliance_id_corporations(alliance_id):
data = [2001, 2002, 2003]
return BravadoOperationStub(data)
class Character:
@staticmethod
def get_characters_character_id(character_id):
data = {
1001: {
"corporation_id": 2001,
"name": "Bruce Wayne",
},
1002: {
"corporation_id": 2001,
"name": "Peter Parker",
},
1011: {
"corporation_id": 2011,
"name": "Lex Luthor",
}
}
try:
return BravadoOperationStub(data[int(character_id)])
except KeyError:
response = BravadoResponseStub(
404, f"Character with ID {character_id} not found"
)
raise HTTPNotFound(response)
@staticmethod
def post_characters_affiliation(characters: list):
data = [
{'character_id': 1001, 'corporation_id': 2001, 'alliance_id': 3001},
{'character_id': 1002, 'corporation_id': 2001, 'alliance_id': 3001},
{'character_id': 1011, 'corporation_id': 2011},
{'character_id': 1666, 'corporation_id': 1000001},
]
return BravadoOperationStub(
[x for x in data if x['character_id'] in characters]
)
class Corporation:
@staticmethod
def get_corporations_corporation_id(corporation_id):
data = {
2001: {
"ceo_id": 1091,
"member_count": 10,
"name": "Wayne Technologies",
"ticker": "WTE",
"alliance_id": 3001
},
2002: {
"ceo_id": 1092,
"member_count": 10,
"name": "Wayne Food",
"ticker": "WFO",
"alliance_id": 3001
},
2003: {
"ceo_id": 1093,
"member_count": 10,
"name": "Wayne Energy",
"ticker": "WEG",
"alliance_id": 3001
},
2011: {
"ceo_id": 1,
"member_count": 3,
"name": "LexCorp",
"ticker": "LC",
},
1000001: {
"ceo_id": 3000001,
"creator_id": 1,
"description": "The internal corporation used for characters in graveyard.",
"member_count": 6329026,
"name": "Doomheim",
"ticker": "666",
}
}
try:
return BravadoOperationStub(data[int(corporation_id)])
except KeyError:
response = BravadoResponseStub(
404, f"Corporation with ID {corporation_id} not found"
)
raise HTTPNotFound(response)
class Universe:
@staticmethod
def post_universe_names(ids: list):
data = [
{"category": "character", "id": 1001, "name": "Bruce Wayne"},
{"category": "character", "id": 1002, "name": "Peter Parker"},
{"category": "character", "id": 1011, "name": "Lex Luthor"},
{"category": "character", "id": 1666, "name": "Hal Jordan"},
{"category": "corporation", "id": 2001, "name": "Wayne Technologies"},
{"category": "corporation","id": 2002, "name": "Wayne Food"},
{"category": "corporation","id": 1000001, "name": "Doomheim"},
]
return BravadoOperationStub([x for x in data if x['id'] in ids])

View File

@@ -1,12 +1,15 @@
from unittest.mock import Mock, patch
from django.core.exceptions import ObjectDoesNotExist
from django.test import TestCase
from esi.models import Token
from allianceauth.tests.auth_utils import AuthUtils
from ..models import (
EveCharacter, EveCorporationInfo, EveAllianceInfo, EveFactionInfo
)
from ..providers import Alliance, Corporation, Character
from ..evelinks import eveimageserver
from ..models import EveAllianceInfo, EveCharacter, EveCorporationInfo, EveFactionInfo
from ..providers import Alliance, Character, Corporation
from .esi_client_stub import EsiClientStub
class EveCharacterTestCase(TestCase):
@@ -402,8 +405,8 @@ class EveAllianceTestCase(TestCase):
my_alliance.save()
my_alliance.populate_alliance()
for corporation in EveCorporationInfo.objects\
.filter(corporation_id__in=[2001, 2002]
for corporation in (
EveCorporationInfo.objects.filter(corporation_id__in=[2001, 2002])
):
self.assertEqual(corporation.alliance, my_alliance)
@@ -587,3 +590,98 @@ class EveCorporationTestCase(TestCase):
self.my_corp.logo_url_256,
'https://images.evetech.net/corporations/2001/logo?size=256'
)
@patch('allianceauth.eveonline.providers.esi_client_factory')
@patch("allianceauth.eveonline.models.notify")
class TestCharacterUpdate(TestCase):
def test_should_update_normal_character(self, mock_notify, mock_esi_client_factory):
# given
mock_esi_client_factory.return_value = EsiClientStub()
my_character = EveCharacter.objects.create(
character_id=1001,
character_name="not my name",
corporation_id=2002,
corporation_name="Wayne Food",
corporation_ticker="WYF",
alliance_id=None
)
# when
my_character.update_character()
# then
my_character.refresh_from_db()
self.assertEqual(my_character.character_name, "Bruce Wayne")
self.assertEqual(my_character.corporation_id, 2001)
self.assertEqual(my_character.corporation_name, "Wayne Technologies")
self.assertEqual(my_character.corporation_ticker, "WTE")
self.assertEqual(my_character.alliance_id, 3001)
self.assertEqual(my_character.alliance_name, "Wayne Enterprises")
self.assertEqual(my_character.alliance_ticker, "WYE")
self.assertFalse(mock_notify.called)
def test_should_update_dead_character_with_owner(
self, mock_notify, mock_esi_client_factory
):
# given
mock_esi_client_factory.return_value = EsiClientStub()
character_1666 = EveCharacter.objects.create(
character_id=1666,
character_name="Hal Jordan",
corporation_id=2002,
corporation_name="Wayne Food",
corporation_ticker="WYF",
alliance_id=None
)
user = AuthUtils.create_user("Bruce Wayne")
token_1666 = Token.objects.create(
user=user,
character_id=character_1666.character_id,
character_name=character_1666.character_name,
character_owner_hash="ABC123-1666",
)
character_1001 = EveCharacter.objects.create(
character_id=1001,
character_name="Bruce Wayne",
corporation_id=2001,
corporation_name="Wayne Technologies",
corporation_ticker="WYT",
alliance_id=None
)
token_1001 = Token.objects.create(
user=user,
character_id=character_1001.character_id,
character_name=character_1001.character_name,
character_owner_hash="ABC123-1001",
)
# when
character_1666.update_character()
# then
character_1666.refresh_from_db()
self.assertTrue(character_1666.is_biomassed)
self.assertNotIn(token_1666, user.token_set.all())
self.assertIn(token_1001, user.token_set.all())
with self.assertRaises(ObjectDoesNotExist):
self.assertTrue(character_1666.character_ownership)
user.profile.refresh_from_db()
self.assertIsNone(user.profile.main_character)
self.assertTrue(mock_notify.called)
def test_should_handle_dead_character_without_owner(
self, mock_notify, mock_esi_client_factory
):
# given
mock_esi_client_factory.return_value = EsiClientStub()
character_1666 = EveCharacter.objects.create(
character_id=1666,
character_name="Hal Jordan",
corporation_id=1011,
corporation_name="LexCorp",
corporation_ticker='LC',
alliance_id=None
)
# when
character_1666.update_character()
# then
character_1666.refresh_from_db()
self.assertTrue(character_1666.is_biomassed)
self.assertFalse(mock_notify.called)

View File

@@ -7,6 +7,7 @@ from jsonschema.exceptions import RefResolutionError
from django.test import TestCase
from . import set_logger
from .esi_client_stub import EsiClientStub
from ..providers import (
ObjectNotFound,
Entity,
@@ -632,13 +633,7 @@ class TestEveSwaggerProvider(TestCase):
@patch(MODULE_PATH + '.esi_client_factory')
def test_get_character(self, mock_esi_client_factory):
mock_esi_client_factory.return_value \
.Character.get_characters_character_id \
= TestEveSwaggerProvider.esi_get_characters_character_id
mock_esi_client_factory.return_value \
.Character.post_characters_affiliation \
= TestEveSwaggerProvider.esi_post_characters_affiliation
mock_esi_client_factory.return_value = EsiClientStub()
my_provider = EveSwaggerProvider()
# character with alliance
@@ -649,8 +644,8 @@ class TestEveSwaggerProvider(TestCase):
self.assertEqual(my_character.alliance_id, 3001)
# character wo/ alliance
my_character = my_provider.get_character(1002)
self.assertEqual(my_character.id, 1002)
my_character = my_provider.get_character(1011)
self.assertEqual(my_character.id, 1011)
self.assertEqual(my_character.alliance_id, None)
# character not found

View File

@@ -1,245 +1,271 @@
from unittest.mock import patch, Mock
from unittest.mock import patch
from django.test import TestCase
from django.test import TestCase, TransactionTestCase, override_settings
from ..models import EveCharacter, EveCorporationInfo, EveAllianceInfo
from ..models import EveAllianceInfo, EveCharacter, EveCorporationInfo
from ..tasks import (
run_model_update,
update_alliance,
update_corp,
update_character,
run_model_update
update_character_chunk,
update_corp,
)
from .esi_client_stub import EsiClientStub
class TestTasks(TestCase):
@patch('allianceauth.eveonline.tasks.EveCorporationInfo')
def test_update_corp(self, mock_EveCorporationInfo):
update_corp(42)
self.assertEqual(
mock_EveCorporationInfo.objects.update_corporation.call_count, 1
)
self.assertEqual(
mock_EveCorporationInfo.objects.update_corporation.call_args[0][0], 42
@patch('allianceauth.eveonline.providers.esi_client_factory')
class TestUpdateTasks(TestCase):
def test_should_update_alliance(self, mock_esi_client_factory):
# given
mock_esi_client_factory.return_value = EsiClientStub()
my_alliance = EveAllianceInfo.objects.create(
alliance_id=3001,
alliance_name="Wayne Enterprises",
alliance_ticker="WYE",
executor_corp_id=2003
)
# when
update_alliance(my_alliance.alliance_id)
# then
my_alliance.refresh_from_db()
self.assertEqual(my_alliance.executor_corp_id, 2001)
@patch('allianceauth.eveonline.tasks.EveAllianceInfo')
def test_update_alliance(self, mock_EveAllianceInfo):
update_alliance(42)
self.assertEqual(
mock_EveAllianceInfo.objects.update_alliance.call_args[0][0], 42
)
self.assertEqual(
mock_EveAllianceInfo.objects
.update_alliance.return_value.populate_alliance.call_count, 1
def test_should_update_character(self, mock_esi_client_factory):
# given
mock_esi_client_factory.return_value = EsiClientStub()
my_character = EveCharacter.objects.create(
character_id=1001,
character_name="Bruce Wayne",
corporation_id=2002,
corporation_name="Wayne Food",
corporation_ticker="WYF",
alliance_id=None
)
# when
update_character(my_character.character_id)
# then
my_character.refresh_from_db()
self.assertEqual(my_character.corporation_id, 2001)
@patch('allianceauth.eveonline.tasks.EveCharacter')
def test_update_character(self, mock_EveCharacter):
update_character(42)
self.assertEqual(
mock_EveCharacter.objects.update_character.call_count, 1
def test_should_update_corp(self, mock_esi_client_factory):
# given
mock_esi_client_factory.return_value = EsiClientStub()
EveAllianceInfo.objects.create(
alliance_id=3001,
alliance_name="Wayne Enterprises",
alliance_ticker="WYE",
executor_corp_id=2003
)
self.assertEqual(
mock_EveCharacter.objects.update_character.call_args[0][0], 42
my_corporation = EveCorporationInfo.objects.create(
corporation_id=2003,
corporation_name="Wayne Food",
corporation_ticker="WFO",
member_count=1,
alliance=None,
ceo_id=1999
)
# when
update_corp(my_corporation.corporation_id)
# then
my_corporation.refresh_from_db()
self.assertEqual(my_corporation.alliance.alliance_id, 3001)
# @patch('allianceauth.eveonline.tasks.EveCharacter')
# def test_update_character(self, mock_EveCharacter):
# update_character(42)
# self.assertEqual(
# mock_EveCharacter.objects.update_character.call_count, 1
# )
# self.assertEqual(
# mock_EveCharacter.objects.update_character.call_args[0][0], 42
# )
@patch('allianceauth.eveonline.tasks.update_character')
@patch('allianceauth.eveonline.tasks.update_alliance')
@patch('allianceauth.eveonline.tasks.update_corp')
@patch('allianceauth.eveonline.providers.provider')
@override_settings(CELERY_ALWAYS_EAGER=True)
@patch('allianceauth.eveonline.providers.esi_client_factory')
@patch('allianceauth.eveonline.tasks.providers')
@patch('allianceauth.eveonline.tasks.CHUNK_SIZE', 2)
class TestRunModelUpdate(TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
EveCorporationInfo.objects.all().delete()
EveAllianceInfo.objects.all().delete()
EveCharacter.objects.all().delete()
class TestRunModelUpdate(TransactionTestCase):
def test_should_run_updates(self, mock_providers, mock_esi_client_factory):
# given
mock_providers.provider.client = EsiClientStub()
mock_esi_client_factory.return_value = EsiClientStub()
EveCorporationInfo.objects.create(
corporation_id=2345,
corporation_name='corp.name',
corporation_ticker='c.c.t',
corporation_id=2001,
corporation_name="Wayne Technologies",
corporation_ticker="WTE",
member_count=10,
alliance=None,
)
EveAllianceInfo.objects.create(
alliance_id=3456,
alliance_name='alliance.name',
alliance_ticker='a.t',
executor_corp_id=5,
alliance_3001 = EveAllianceInfo.objects.create(
alliance_id=3001,
alliance_name="Wayne Enterprises",
alliance_ticker="WYE",
executor_corp_id=2003
)
EveCharacter.objects.create(
character_id=1,
character_name='character.name1',
corporation_id=2345,
corporation_name='character.corp.name',
corporation_ticker='c.c.t', # max 5 chars
corporation_2003 = EveCorporationInfo.objects.create(
corporation_id=2003,
corporation_name="Wayne Energy",
corporation_ticker="WEG",
member_count=99,
alliance=None,
)
character_1001 = EveCharacter.objects.create(
character_id=1001,
character_name="Bruce Wayne",
corporation_id=2002,
corporation_name="Wayne Food",
corporation_ticker="WYF",
alliance_id=None
)
EveCharacter.objects.create(
character_id=2,
character_name='character.name2',
corporation_id=9876,
corporation_name='character.corp.name',
corporation_ticker='c.c.t', # max 5 chars
alliance_id=3456,
alliance_name='character.alliance.name',
)
EveCharacter.objects.create(
character_id=3,
character_name='character.name3',
corporation_id=9876,
corporation_name='character.corp.name',
corporation_ticker='c.c.t', # max 5 chars
alliance_id=3456,
alliance_name='character.alliance.name',
)
EveCharacter.objects.create(
character_id=4,
character_name='character.name4',
corporation_id=9876,
corporation_name='character.corp.name',
corporation_ticker='c.c.t', # max 5 chars
alliance_id=3456,
alliance_name='character.alliance.name',
)
"""
EveCharacter.objects.create(
character_id=5,
character_name='character.name5',
corporation_id=9876,
corporation_name='character.corp.name',
corporation_ticker='c.c.t', # max 5 chars
alliance_id=3456,
alliance_name='character.alliance.name',
)
"""
def setUp(self):
self.affiliations = [
{'character_id': 1, 'corporation_id': 5},
{'character_id': 2, 'corporation_id': 9876, 'alliance_id': 3456},
{'character_id': 3, 'corporation_id': 9876, 'alliance_id': 7456},
{'character_id': 4, 'corporation_id': 9876, 'alliance_id': 3456}
]
self.names = [
{'id': 1, 'name': 'character.name1'},
{'id': 2, 'name': 'character.name2'},
{'id': 3, 'name': 'character.name3'},
{'id': 4, 'name': 'character.name4_new'}
]
def test_normal_run(
self,
mock_provider,
mock_update_corp,
mock_update_alliance,
mock_update_character,
):
def get_affiliations(characters: list):
response = [x for x in self.affiliations if x['character_id'] in characters]
mock_operator = Mock(**{'result.return_value': response})
return mock_operator
def get_names(ids: list):
response = [x for x in self.names if x['id'] in ids]
mock_operator = Mock(**{'result.return_value': response})
return mock_operator
mock_provider.client.Character.post_characters_affiliation.side_effect \
= get_affiliations
mock_provider.client.Universe.post_universe_names.side_effect = get_names
# when
run_model_update()
# then
character_1001.refresh_from_db()
self.assertEqual(
mock_provider.client.Character.post_characters_affiliation.call_count, 2
character_1001.corporation_id, 2001 # char has new corp
)
corporation_2003.refresh_from_db()
self.assertEqual(
mock_provider.client.Universe.post_universe_names.call_count, 2
corporation_2003.alliance.alliance_id, 3001 # corp has new alliance
)
alliance_3001.refresh_from_db()
self.assertEqual(
alliance_3001.executor_corp_id, 2001 # alliance has been updated
)
# character 1 has changed corp
# character 2 no change
# character 3 has changed alliance
# character 4 has changed name
self.assertEqual(mock_update_corp.apply_async.call_count, 1)
self.assertEqual(
int(mock_update_corp.apply_async.call_args[1]['args'][0]), 2345
)
self.assertEqual(mock_update_alliance.apply_async.call_count, 1)
self.assertEqual(
int(mock_update_alliance.apply_async.call_args[1]['args'][0]), 3456
)
characters_updated = {
x[1]['args'][0] for x in mock_update_character.apply_async.call_args_list
@override_settings(CELERY_ALWAYS_EAGER=True)
@patch('allianceauth.eveonline.tasks.update_character', wraps=update_character)
@patch('allianceauth.eveonline.providers.esi_client_factory')
@patch('allianceauth.eveonline.tasks.providers')
@patch('allianceauth.eveonline.tasks.CHUNK_SIZE', 2)
class TestUpdateCharacterChunk(TestCase):
@staticmethod
def _updated_character_ids(spy_update_character) -> set:
"""Character IDs passed to update_character task for update."""
return {
x[1]["args"][0] for x in spy_update_character.apply_async.call_args_list
}
excepted = {1, 3, 4}
self.assertSetEqual(characters_updated, excepted)
def test_ignore_character_not_in_affiliations(
self,
mock_provider,
mock_update_corp,
mock_update_alliance,
mock_update_character,
def test_should_update_corp_change(
self, mock_providers, mock_esi_client_factory, spy_update_character
):
def get_affiliations(characters: list):
response = [x for x in self.affiliations if x['character_id'] in characters]
mock_operator = Mock(**{'result.return_value': response})
return mock_operator
# given
mock_providers.provider.client = EsiClientStub()
mock_esi_client_factory.return_value = EsiClientStub()
character_1001 = EveCharacter.objects.create(
character_id=1001,
character_name="Bruce Wayne",
corporation_id=2003,
corporation_name="Wayne Energy",
corporation_ticker="WEG",
alliance_id=3001,
alliance_name="Wayne Enterprises",
alliance_ticker="WYE",
)
character_1002 = EveCharacter.objects.create(
character_id=1002,
character_name="Peter Parker",
corporation_id=2001,
corporation_name="Wayne Technologies",
corporation_ticker="WTE",
alliance_id=3001,
alliance_name="Wayne Enterprises",
alliance_ticker="WYE",
)
# when
update_character_chunk([
character_1001.character_id, character_1002.character_id
])
# then
character_1001.refresh_from_db()
self.assertEqual(character_1001.corporation_id, 2001)
self.assertSetEqual(self._updated_character_ids(spy_update_character), {1001})
def get_names(ids: list):
response = [x for x in self.names if x['id'] in ids]
mock_operator = Mock(**{'result.return_value': response})
return mock_operator
del self.affiliations[0]
mock_provider.client.Character.post_characters_affiliation.side_effect \
= get_affiliations
mock_provider.client.Universe.post_universe_names.side_effect = get_names
run_model_update()
characters_updated = {
x[1]['args'][0] for x in mock_update_character.apply_async.call_args_list
}
excepted = {3, 4}
self.assertSetEqual(characters_updated, excepted)
def test_ignore_character_not_in_names(
self,
mock_provider,
mock_update_corp,
mock_update_alliance,
mock_update_character,
def test_should_update_name_change(
self, mock_providers, mock_esi_client_factory, spy_update_character
):
def get_affiliations(characters: list):
response = [x for x in self.affiliations if x['character_id'] in characters]
mock_operator = Mock(**{'result.return_value': response})
return mock_operator
# given
mock_providers.provider.client = EsiClientStub()
mock_esi_client_factory.return_value = EsiClientStub()
character_1001 = EveCharacter.objects.create(
character_id=1001,
character_name="Batman",
corporation_id=2001,
corporation_name="Wayne Technologies",
corporation_ticker="WTE",
alliance_id=3001,
alliance_name="Wayne Technologies",
alliance_ticker="WYT",
)
# when
update_character_chunk([character_1001.character_id])
# then
character_1001.refresh_from_db()
self.assertEqual(character_1001.character_name, "Bruce Wayne")
self.assertSetEqual(self._updated_character_ids(spy_update_character), {1001})
def get_names(ids: list):
response = [x for x in self.names if x['id'] in ids]
mock_operator = Mock(**{'result.return_value': response})
return mock_operator
def test_should_update_alliance_change(
self, mock_providers, mock_esi_client_factory, spy_update_character
):
# given
mock_providers.provider.client = EsiClientStub()
mock_esi_client_factory.return_value = EsiClientStub()
character_1001 = EveCharacter.objects.create(
character_id=1001,
character_name="Bruce Wayne",
corporation_id=2001,
corporation_name="Wayne Technologies",
corporation_ticker="WTE",
alliance_id=None,
)
# when
update_character_chunk([character_1001.character_id])
# then
character_1001.refresh_from_db()
self.assertEqual(character_1001.alliance_id, 3001)
self.assertSetEqual(self._updated_character_ids(spy_update_character), {1001})
del self.names[3]
def test_should_not_update_when_not_changed(
self, mock_providers, mock_esi_client_factory, spy_update_character
):
# given
mock_providers.provider.client = EsiClientStub()
mock_esi_client_factory.return_value = EsiClientStub()
character_1001 = EveCharacter.objects.create(
character_id=1001,
character_name="Bruce Wayne",
corporation_id=2001,
corporation_name="Wayne Technologies",
corporation_ticker="WTE",
alliance_id=3001,
alliance_name="Wayne Technologies",
alliance_ticker="WYT",
)
# when
update_character_chunk([character_1001.character_id])
# then
self.assertSetEqual(self._updated_character_ids(spy_update_character), set())
mock_provider.client.Character.post_characters_affiliation.side_effect \
= get_affiliations
mock_provider.client.Universe.post_universe_names.side_effect = get_names
run_model_update()
characters_updated = {
x[1]['args'][0] for x in mock_update_character.apply_async.call_args_list
}
excepted = {1, 3}
self.assertSetEqual(characters_updated, excepted)
def test_should_fall_back_to_single_updates_when_bulk_update_failed(
self, mock_providers, mock_esi_client_factory, spy_update_character
):
# given
mock_providers.provider.client.Character.post_characters_affiliation\
.side_effect = OSError
mock_esi_client_factory.return_value = EsiClientStub()
character_1001 = EveCharacter.objects.create(
character_id=1001,
character_name="Bruce Wayne",
corporation_id=2001,
corporation_name="Wayne Technologies",
corporation_ticker="WTE",
alliance_id=3001,
alliance_name="Wayne Technologies",
alliance_ticker="WYT",
)
# when
update_character_chunk([character_1001.character_id])
# then
self.assertSetEqual(self._updated_character_ids(spy_update_character), {1001})

View File

@@ -212,7 +212,14 @@ def fatlink_monthly_personal_statistics_view(request, year, month, char_id=None)
start_of_previous_month = first_day_of_previous_month(year, month)
if request.user.has_perm('auth.fleetactivitytracking_statistics') and char_id:
user = EveCharacter.objects.get(character_id=char_id).user
try:
user = EveCharacter.objects.get(character_id=char_id).character_ownership.user
except EveCharacter.DoesNotExist:
messages.error(request, _('Character does not exist'))
return redirect('fatlink:view')
except AttributeError:
messages.error(request, _('User does not exist'))
return redirect('fatlink:view')
else:
user = request.user
logger.debug(f"Personal monthly statistics view for user {user} called by {request.user}")

View File

@@ -1,19 +1,21 @@
from django import forms
from django.apps import apps
from django.contrib.auth.models import Permission
from django.contrib import admin
from django.contrib.auth.models import Group as BaseGroup, User
from django.core.exceptions import ValidationError
from django.db.models import Count
from django.db.models.functions import Lower
from django.db.models.signals import pre_save, post_save, pre_delete, \
post_delete, m2m_changed
from django.dispatch import receiver
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from .models import AuthGroup, ReservedGroupName
from .models import GroupRequest
from django.contrib.auth.models import Group as BaseGroup, Permission, User
from django.db.models import Count, Exists, OuterRef
from django.db.models.functions import Lower
from django.db.models.signals import (
m2m_changed,
post_delete,
post_save,
pre_delete,
pre_save
)
from django.dispatch import receiver
from .forms import GroupAdminForm, ReservedGroupNameAdminForm
from .models import AuthGroup, GroupRequest, ReservedGroupName
from .tasks import remove_users_not_matching_states_from_group
if 'eve_autogroups' in apps.app_configs:
_has_auto_groups = True
@@ -28,10 +30,12 @@ class AuthGroupInlineAdmin(admin.StackedInline):
'description',
'group_leaders',
'group_leader_groups',
'states', 'internal',
'states',
'internal',
'hidden',
'open',
'public'
'public',
'restricted',
)
verbose_name_plural = 'Auth Settings'
verbose_name = ''
@@ -50,6 +54,11 @@ class AuthGroupInlineAdmin(admin.StackedInline):
def has_change_permission(self, request, obj=None):
return request.user.has_perm('auth.change_group')
def get_readonly_fields(self, request, obj=None):
if not request.user.is_superuser:
return self.readonly_fields + ("restricted",)
return self.readonly_fields
if _has_auto_groups:
class IsAutoGroupFilter(admin.SimpleListFilter):
@@ -96,27 +105,15 @@ class HasLeaderFilter(admin.SimpleListFilter):
return queryset
class GroupAdminForm(forms.ModelForm):
def clean_name(self):
my_name = self.cleaned_data['name']
if ReservedGroupName.objects.filter(name__iexact=my_name).exists():
raise ValidationError(
_("This name has been reserved and can not be used for groups."),
code='reserved_name'
)
return my_name
class GroupAdmin(admin.ModelAdmin):
form = GroupAdminForm
list_select_related = ('authgroup',)
ordering = ('name',)
list_display = (
'name',
'_description',
'_properties',
'_member_count',
'has_leader'
'has_leader',
)
list_filter = [
'authgroup__internal',
@@ -132,34 +129,51 @@ class GroupAdmin(admin.ModelAdmin):
def get_queryset(self, request):
qs = super().get_queryset(request)
if _has_auto_groups:
qs = qs.prefetch_related('managedalliancegroup_set', 'managedcorpgroup_set')
qs = qs.prefetch_related('authgroup__group_leaders').select_related('authgroup')
qs = qs.annotate(
member_count=Count('user', distinct=True),
has_leader_qs = (
AuthGroup.objects.filter(group=OuterRef('pk'), group_leaders__isnull=False)
)
has_leader_groups_qs = (
AuthGroup.objects.filter(
group=OuterRef('pk'), group_leader_groups__isnull=False
)
)
qs = (
qs.select_related('authgroup')
.annotate(member_count=Count('user', distinct=True))
.annotate(has_leader=Exists(has_leader_qs))
.annotate(has_leader_groups=Exists(has_leader_groups_qs))
)
if _has_auto_groups:
is_autogroup_corp = (
Group.objects.filter(
pk=OuterRef('pk'), managedcorpgroup__isnull=False
)
)
is_autogroup_alliance = (
Group.objects.filter(
pk=OuterRef('pk'), managedalliancegroup__isnull=False
)
)
qs = (
qs.annotate(is_autogroup_corp=Exists(is_autogroup_corp))
.annotate(is_autogroup_alliance=Exists(is_autogroup_alliance))
)
return qs
def _description(self, obj):
return obj.authgroup.description
@admin.display(description='Members', ordering='member_count')
def _member_count(self, obj):
return obj.member_count
_member_count.short_description = 'Members'
_member_count.admin_order_field = 'member_count'
@admin.display(boolean=True)
def has_leader(self, obj):
return obj.authgroup.group_leaders.exists() or obj.authgroup.group_leader_groups.exists()
has_leader.boolean = True
return obj.has_leader or obj.has_leader_groups
def _properties(self, obj):
properties = list()
if _has_auto_groups and (
obj.managedalliancegroup_set.exists()
or obj.managedcorpgroup_set.exists()
):
if _has_auto_groups and (obj.is_autogroup_corp or obj.is_autogroup_alliance):
properties.append('Auto Group')
elif obj.authgroup.internal:
properties.append('Internal')
@@ -172,11 +186,10 @@ class GroupAdmin(admin.ModelAdmin):
properties.append('Public')
if not properties:
properties.append('Default')
if obj.authgroup.restricted:
properties.append('Restricted')
return properties
_properties.short_description = "properties"
filter_horizontal = ('permissions',)
inlines = (AuthGroupInlineAdmin,)
@@ -190,8 +203,15 @@ class GroupAdmin(admin.ModelAdmin):
ag_instance = inline_form.save(commit=False)
ag_instance.group = form.instance
ag_instance.save()
if ag_instance.states.exists():
remove_users_not_matching_states_from_group.delay(ag_instance.group.pk)
formset.save()
def get_readonly_fields(self, request, obj=None):
if not request.user.is_superuser:
return self.readonly_fields + ("permissions",)
return self.readonly_fields
class Group(BaseGroup):
class Meta:
@@ -216,33 +236,10 @@ class GroupRequestAdmin(admin.ModelAdmin):
'leave_request',
)
@admin.display(boolean=True, description="is leave request")
def _leave_request(self, obj) -> True:
return obj.leave_request
_leave_request.short_description = 'is leave request'
_leave_request.boolean = True
class ReservedGroupNameAdminForm(forms.ModelForm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields['created_by'].initial = self.current_user.username
self.fields['created_at'].initial = _("(auto)")
created_by = forms.CharField(disabled=True)
created_at = forms.CharField(disabled=True)
def clean_name(self):
my_name = self.cleaned_data['name'].lower()
if Group.objects.filter(name__iexact=my_name).exists():
raise ValidationError(
_("There already exists a group with that name."), code='already_exists'
)
return my_name
def clean_created_at(self):
return now()
@admin.register(ReservedGroupName)
class ReservedGroupNameAdmin(admin.ModelAdmin):

View File

@@ -0,0 +1,39 @@
from django import forms
from django.contrib.auth.models import Group
from django.core.exceptions import ValidationError
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from .models import ReservedGroupName
class GroupAdminForm(forms.ModelForm):
def clean_name(self):
my_name = self.cleaned_data['name']
if ReservedGroupName.objects.filter(name__iexact=my_name).exists():
raise ValidationError(
_("This name has been reserved and can not be used for groups."),
code='reserved_name'
)
return my_name
class ReservedGroupNameAdminForm(forms.ModelForm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields['created_by'].initial = self.current_user.username
self.fields['created_at'].initial = _("(auto)")
created_by = forms.CharField(disabled=True)
created_at = forms.CharField(disabled=True)
def clean_name(self):
my_name = self.cleaned_data['name'].lower()
if Group.objects.filter(name__iexact=my_name).exists():
raise ValidationError(
_("There already exists a group with that name."), code='already_exists'
)
return my_name
def clean_created_at(self):
return now()

View File

@@ -0,0 +1,18 @@
# Generated by Django 3.2.10 on 2022-04-08 19:30
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('groupmanagement', '0018_reservedgroupname'),
]
operations = [
migrations.AddField(
model_name='authgroup',
name='restricted',
field=models.BooleanField(default=False, help_text='Group is restricted. This means that adding or removing users for this group requires a superuser admin.'),
),
]

View File

@@ -13,6 +13,7 @@ from allianceauth.notifications import notify
class GroupRequest(models.Model):
"""Request from a user for joining or leaving a group."""
leave_request = models.BooleanField(default=0)
user = models.ForeignKey(User, on_delete=models.CASCADE)
group = models.ForeignKey(Group, on_delete=models.CASCADE)
@@ -44,6 +45,7 @@ class GroupRequest(models.Model):
class RequestLog(models.Model):
"""Log entry about who joined and left a group and who approved it."""
request_type = models.BooleanField(null=True)
group = models.ForeignKey(Group, on_delete=models.CASCADE)
request_info = models.CharField(max_length=254)
@@ -95,6 +97,7 @@ class AuthGroup(models.Model):
Open - Users are automatically accepted into the group
Not Open - Users requests must be approved before they are added to the group
"""
group = models.OneToOneField(Group, on_delete=models.CASCADE, primary_key=True)
internal = models.BooleanField(
default=True,
@@ -126,6 +129,13 @@ class AuthGroup(models.Model):
"are no longer authenticated."
)
)
restricted = models.BooleanField(
default=False,
help_text=_(
"Group is restricted. This means that adding or removing users "
"for this group requires a superuser admin."
)
)
group_leaders = models.ManyToManyField(
User,
related_name='leads_groups',
@@ -179,12 +189,22 @@ class AuthGroup(models.Model):
| User.objects.filter(groups__in=list(self.group_leader_groups.all()))
)
def remove_users_not_matching_states(self):
"""Remove users not matching defined states from related group."""
states_qs = self.states.all()
if states_qs.exists():
states = list(states_qs)
non_compliant_users = self.group.user_set.exclude(profile__state__in=states)
for user in non_compliant_users:
self.group.user_set.remove(user)
class ReservedGroupName(models.Model):
"""Name that can not be used for groups.
This enables AA to ignore groups on other services (e.g. Discord) with that name.
"""
name = models.CharField(
_('name'),
max_length=150,

View File

@@ -0,0 +1,10 @@
from celery import shared_task
from django.contrib.auth.models import Group
@shared_task
def remove_users_not_matching_states_from_group(group_pk: int) -> None:
"""Remove users not matching defined states from related group."""
group = Group.objects.get(pk=group_pk)
group.authgroup.remove_users_not_matching_states()

View File

@@ -127,6 +127,8 @@
],
bootstrap: true
},
"stateSave": true,
"stateDuration": 0
});
});
{% endblock %}

View File

@@ -104,7 +104,9 @@
"sortable": false,
"targets": [2]
},
]
],
"stateSave": true,
"stateDuration": 0
});
});
{% endblock %}

View File

@@ -1,17 +1,21 @@
from unittest.mock import patch
from django_webtest import WebTest
from django.conf import settings
from django.contrib import admin
from django.contrib.admin.sites import AdminSite
from django.contrib.auth.models import User
from django.test import TestCase, RequestFactory, Client
from django.test import TestCase, RequestFactory, Client, override_settings
from allianceauth.authentication.models import CharacterOwnership, State
from allianceauth.eveonline.models import (
EveCharacter, EveCorporationInfo, EveAllianceInfo
)
from ..admin import HasLeaderFilter, GroupAdmin, Group
from allianceauth.tests.auth_utils import AuthUtils
from . import get_admin_change_view_url
from ..admin import HasLeaderFilter, GroupAdmin, Group
from ..models import ReservedGroupName
@@ -33,7 +37,6 @@ class MockRequest:
class TestGroupAdmin(TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
@@ -233,60 +236,104 @@ class TestGroupAdmin(TestCase):
self.assertEqual(result, expected)
def test_member_count(self):
expected = 1
obj = self.modeladmin.get_queryset(MockRequest(user=self.user_1))\
.get(pk=self.group_1.pk)
# given
request = MockRequest(user=self.user_1)
obj = self.modeladmin.get_queryset(request).get(pk=self.group_1.pk)
# when
result = self.modeladmin._member_count(obj)
self.assertEqual(result, expected)
# then
self.assertEqual(result, 1)
def test_has_leader_user(self):
result = self.modeladmin.has_leader(self.group_1)
# given
request = MockRequest(user=self.user_1)
obj = self.modeladmin.get_queryset(request).get(pk=self.group_1.pk)
# when
result = self.modeladmin.has_leader(obj)
# then
self.assertTrue(result)
def test_has_leader_group(self):
result = self.modeladmin.has_leader(self.group_2)
# given
request = MockRequest(user=self.user_1)
obj = self.modeladmin.get_queryset(request).get(pk=self.group_2.pk)
# when
result = self.modeladmin.has_leader(obj)
# then
self.assertTrue(result)
def test_properties_1(self):
expected = ['Default']
result = self.modeladmin._properties(self.group_1)
self.assertListEqual(result, expected)
# given
request = MockRequest(user=self.user_1)
obj = self.modeladmin.get_queryset(request).get(pk=self.group_1.pk)
# when
result = self.modeladmin._properties(obj)
self.assertListEqual(result, ['Default'])
def test_properties_2(self):
expected = ['Internal']
result = self.modeladmin._properties(self.group_2)
self.assertListEqual(result, expected)
# given
request = MockRequest(user=self.user_1)
obj = self.modeladmin.get_queryset(request).get(pk=self.group_2.pk)
# when
result = self.modeladmin._properties(obj)
self.assertListEqual(result, ['Internal'])
def test_properties_3(self):
expected = ['Hidden']
result = self.modeladmin._properties(self.group_3)
self.assertListEqual(result, expected)
# given
request = MockRequest(user=self.user_1)
obj = self.modeladmin.get_queryset(request).get(pk=self.group_3.pk)
# when
result = self.modeladmin._properties(obj)
self.assertListEqual(result, ['Hidden'])
def test_properties_4(self):
expected = ['Open']
result = self.modeladmin._properties(self.group_4)
self.assertListEqual(result, expected)
# given
request = MockRequest(user=self.user_1)
obj = self.modeladmin.get_queryset(request).get(pk=self.group_4.pk)
# when
result = self.modeladmin._properties(obj)
self.assertListEqual(result, ['Open'])
def test_properties_5(self):
expected = ['Public']
result = self.modeladmin._properties(self.group_5)
self.assertListEqual(result, expected)
# given
request = MockRequest(user=self.user_1)
obj = self.modeladmin.get_queryset(request).get(pk=self.group_5.pk)
# when
result = self.modeladmin._properties(obj)
self.assertListEqual(result, ['Public'])
def test_properties_6(self):
expected = ['Hidden', 'Open', 'Public']
result = self.modeladmin._properties(self.group_6)
self.assertListEqual(result, expected)
# given
request = MockRequest(user=self.user_1)
obj = self.modeladmin.get_queryset(request).get(pk=self.group_6.pk)
# when
result = self.modeladmin._properties(obj)
self.assertListEqual(result, ['Hidden', 'Open', 'Public'])
if _has_auto_groups:
@patch(MODULE_PATH + '._has_auto_groups', True)
def test_properties_7(self):
def test_should_show_autogroup_for_corporation(self):
# given
self._create_autogroups()
expected = ['Auto Group']
my_group = Group.objects\
.filter(managedcorpgroup__isnull=False)\
.first()
result = self.modeladmin._properties(my_group)
self.assertListEqual(result, expected)
request = MockRequest(user=self.user_1)
queryset = self.modeladmin.get_queryset(request)
obj = queryset.filter(managedcorpgroup__isnull=False).first()
# when
result = self.modeladmin._properties(obj)
# then
self.assertListEqual(result, ['Auto Group'])
@patch(MODULE_PATH + '._has_auto_groups', True)
def test_should_show_autogroup_for_alliance(self):
# given
self._create_autogroups()
request = MockRequest(user=self.user_1)
queryset = self.modeladmin.get_queryset(request)
obj = queryset.filter(managedalliancegroup__isnull=False).first()
# when
result = self.modeladmin._properties(obj)
# then
self.assertListEqual(result, ['Auto Group'])
# actions
@@ -468,6 +515,136 @@ class TestGroupAdmin(TestCase):
self.assertFalse(Group.objects.filter(name="new group").exists())
class TestGroupAdminChangeFormSuperuserExclusiveEdits(WebTest):
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.super_admin = User.objects.create_superuser("super_admin")
cls.staff_admin = User.objects.create_user("staff_admin")
cls.staff_admin.is_staff = True
cls.staff_admin.save()
cls.staff_admin = AuthUtils.add_permissions_to_user_by_name(
[
"auth.add_group",
"auth.change_group",
"auth.view_group",
"groupmanagement.add_group",
"groupmanagement.change_group",
"groupmanagement.view_group",
],
cls.staff_admin
)
cls.superuser_exclusive_fields = ["permissions", "authgroup-0-restricted"]
def test_should_show_all_fields_to_superuser_for_add(self):
# given
self.app.set_user(self.super_admin)
page = self.app.get("/admin/groupmanagement/group/add/")
# when
form = page.forms["group_form"]
# then
for field in self.superuser_exclusive_fields:
with self.subTest(field=field):
self.assertIn(field, form.fields)
def test_should_not_show_all_fields_to_staff_admins_for_add(self):
# given
self.app.set_user(self.staff_admin)
page = self.app.get("/admin/groupmanagement/group/add/")
# when
form = page.forms["group_form"]
# then
for field in self.superuser_exclusive_fields:
with self.subTest(field=field):
self.assertNotIn(field, form.fields)
def test_should_show_all_fields_to_superuser_for_change(self):
# given
self.app.set_user(self.super_admin)
group = Group.objects.create(name="Dummy group")
page = self.app.get(f"/admin/groupmanagement/group/{group.pk}/change/")
# when
form = page.forms["group_form"]
# then
for field in self.superuser_exclusive_fields:
with self.subTest(field=field):
self.assertIn(field, form.fields)
def test_should_not_show_all_fields_to_staff_admin_for_change(self):
# given
self.app.set_user(self.staff_admin)
group = Group.objects.create(name="Dummy group")
page = self.app.get(f"/admin/groupmanagement/group/{group.pk}/change/")
# when
form = page.forms["group_form"]
# then
for field in self.superuser_exclusive_fields:
with self.subTest(field=field):
self.assertNotIn(field, form.fields)
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
class TestGroupAdmin2(TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.superuser = User.objects.create_superuser("super")
def test_should_remove_users_from_state_groups(self):
# given
user_member = AuthUtils.create_user("Bruce Wayne")
character_member = AuthUtils.add_main_character_2(
user_member,
name="Bruce Wayne",
character_id=1001,
corp_id=2001,
corp_name="Wayne Technologies",
)
user_guest = AuthUtils.create_user("Lex Luthor")
AuthUtils.add_main_character_2(
user_guest,
name="Lex Luthor",
character_id=1011,
corp_id=2011,
corp_name="Luthor Corp",
)
member_state = AuthUtils.get_member_state()
member_state.member_characters.add(character_member)
user_member.refresh_from_db()
user_guest.refresh_from_db()
group = Group.objects.create(name="dummy")
user_member.groups.add(group)
user_guest.groups.add(group)
group.authgroup.states.add(member_state)
self.client.force_login(self.superuser)
# when
response = self.client.post(
f"/admin/groupmanagement/group/{group.pk}/change/",
data={
"name": f"{group.name}",
"authgroup-TOTAL_FORMS": "1",
"authgroup-INITIAL_FORMS": "1",
"authgroup-MIN_NUM_FORMS": "0",
"authgroup-MAX_NUM_FORMS": "1",
"authgroup-0-description": "",
"authgroup-0-states": f"{member_state.pk}",
"authgroup-0-internal": "on",
"authgroup-0-hidden": "on",
"authgroup-0-group": f"{group.pk}",
"authgroup-__prefix__-description": "",
"authgroup-__prefix__-internal": "on",
"authgroup-__prefix__-hidden": "on",
"authgroup-__prefix__-group": f"{group.pk}",
"_save": "Save"
}
)
# then
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, "/admin/groupmanagement/group/")
self.assertIn(group, user_member.groups.all())
self.assertNotIn(group, user_guest.groups.all())
class TestReservedGroupNameAdmin(TestCase):
@classmethod
def setUpClass(cls):

View File

@@ -232,6 +232,38 @@ class TestAuthGroup(TestCase):
expected = 'Superheros'
self.assertEqual(str(group.authgroup), expected)
def test_should_remove_guests_from_group_when_restricted_to_members_only(self):
# given
user_member = AuthUtils.create_user("Bruce Wayne")
character_member = AuthUtils.add_main_character_2(
user_member,
name="Bruce Wayne",
character_id=1001,
corp_id=2001,
corp_name="Wayne Technologies",
)
user_guest = AuthUtils.create_user("Lex Luthor")
AuthUtils.add_main_character_2(
user_guest,
name="Lex Luthor",
character_id=1011,
corp_id=2011,
corp_name="Luthor Corp",
)
member_state = AuthUtils.get_member_state()
member_state.member_characters.add(character_member)
user_member.refresh_from_db()
user_guest.refresh_from_db()
group = Group.objects.create(name="dummy")
user_member.groups.add(group)
user_guest.groups.add(group)
group.authgroup.states.add(member_state)
# when
group.authgroup.remove_users_not_matching_states()
# then
self.assertIn(group, user_member.groups.all())
self.assertNotIn(group, user_guest.groups.all())
class TestAuthGroupRequestApprovers(TestCase):
def setUp(self) -> None:

View File

@@ -1,9 +1,3 @@
from .core import notify # noqa: F401
default_app_config = 'allianceauth.notifications.apps.NotificationsConfig'
def notify(
user: object, title: str, message: str = None, level: str = 'info'
) -> None:
"""Sends a new notification to user. Convenience function to manager pendant."""
from .models import Notification
Notification.objects.notify_user(user, title, message, level)

View File

@@ -0,0 +1,33 @@
class NotifyApiWrapper:
"""Wrapper to create notify API."""
def __call__(self, *args, **kwargs): # provide old API for backwards compatibility
return self._add_notification(*args, **kwargs)
def danger(self, user: object, title: str, message: str = None) -> None:
"""Add danger notification for user."""
self._add_notification(user, title, message, level="danger")
def info(self, user: object, title: str, message: str = None) -> None:
"""Add info notification for user."""
self._add_notification(user=user, title=title, message=message, level="info")
def success(self, user: object, title: str, message: str = None) -> None:
"""Add success notification for user."""
self._add_notification(user, title, message, level="success")
def warning(self, user: object, title: str, message: str = None) -> None:
"""Add warning notification for user."""
self._add_notification(user, title, message, level="warning")
def _add_notification(
self, user: object, title: str, message: str = None, level: str = "info"
) -> None:
from .models import Notification
Notification.objects.notify_user(
user=user, title=title, message=message, level=level
)
notify = NotifyApiWrapper()

View File

@@ -5,91 +5,34 @@
{% block page_title %}{% translate "Notifications" %}{% endblock %}
{% block content %}
<div class="col-lg-12">
<h1 class="page-header text-center">{% translate "Notifications" %}</h1>
<div class="col-lg-12 container" id="example">
<div class="row">
<div class="col-lg-12">
<div class="panel panel-default">
<div class="panel-heading">
<ul class="nav nav-pills">
<li class="active"><a data-toggle="pill" href="#unread">{% translate "Unread" %}
<b>({{ unread|length }})</b></a></li>
<li><a data-toggle="pill" href="#read">{% translate "Read" %} <b>({{ read|length }})</b></a>
</li>
<div class="pull-right">
<a href="{% url 'notifications:mark_all_read' %}" class="btn btn-primary">{% translate "Mark All Read" %}</a>
<a href="{% url 'notifications:delete_all_read' %}" class="btn btn-danger">{% translate "Delete All Read" %}</a>
</div>
</ul>
</div>
<div class="panel-body">
<div class="tab-content">
<div id="unread" class="tab-pane fade in active">
<div class="table-responsive">
{% if unread %}
<table class="table table-condensed table-hover table-striped">
<tr>
<th class="text-center">{% translate "Timestamp" %}</th>
<th class="text-center">{% translate "Title" %}</th>
<th class="text-center">{% translate "Action" %}</th>
</tr>
{% for notif in unread %}
<tr class="{{ notif.level }}">
<td class="text-center">{{ notif.timestamp }}</td>
<td class="text-center">{{ notif.title }}</td>
<td class="text-center">
<a href="{% url 'notifications:view' notif.id %}" class="btn btn-success" title="View">
<span class="glyphicon glyphicon-eye-open"></span>
</a>
<a href="{% url 'notifications:remove' notif.id %}" class="btn btn-danger" title="Remove">
<span class="glyphicon glyphicon-remove"></span>
</a>
</td>
</tr>
{% endfor %}
</table>
{% else %}
<div class="alert alert-warning text-center">{% translate "No unread notifications." %}</div>
{% endif %}
</div>
</div>
<div id="read" class="tab-pane fade">
<div class="panel-body">
<div class="table-responsive">
{% if read %}
<table class="table table-condensed table-hover table-striped">
<tr>
<th class="text-center">{% translate "Timestamp" %}</th>
<th class="text-center">{% translate "Title" %}</th>
<th class="text-center">{% translate "Action" %}</th>
</tr>
{% for notif in read %}
<tr class="{{ notif.level }}">
<td class="text-center">{{ notif.timestamp }}</td>
<td class="text-center">{{ notif.title }}</td>
<td class="text-center">
<a href="{% url 'notifications:view' notif.id %}" class="btn btn-success" title="View">
<span class="glyphicon glyphicon-eye-open"></span>
</a>
<a href="{% url 'notifications:remove' notif.id %}" class="btn btn-danger" title="remove">
<span class="glyphicon glyphicon-remove"></span>
</a>
</td>
</tr>
{% endfor %}
</table>
{% else %}
<div class="alert alert-warning text-center">{% translate "No read notifications." %}</div>
{% endif %}
</div>
</div>
</div>
</div>
</div>
</div>
<h1 class="page-header text-center">{% translate "Notifications" %}</h1>
<div class="panel panel-default">
<div class="panel-heading">
<ul class="nav nav-pills">
<li class="active"><a data-toggle="tab" href="#unread">{% translate "Unread" %}<b>({{ unread|length }})</b></a></li>
<li><a data-toggle="tab" href="#read">{% translate "Read" %} <b>({{ read|length }})</b></a></li>
<div class="pull-right">
<a href="{% url 'notifications:mark_all_read' %}" class="btn btn-warning">{% translate "Mark All Read" %}</a>
<a href="{% url 'notifications:delete_all_read' %}" class="btn btn-danger">{% translate "Delete All Read" %}</a>
</div>
</ul>
</div>
<div class="panel-body">
<div class="tab-content">
<div id="unread" class="tab-pane fade in active">
{% include "notifications/list_partial.html" with notifications=unread %}
</div>
<div id="read" class="tab-pane fade">
{% include "notifications/list_partial.html" with notifications=read %}
</div>
</div>
</div>
</div>
{% endblock %}

View File

@@ -0,0 +1,29 @@
{% load i18n %}
{% if notifications %}
<div class="table-responsive">
<table class="table table-condensed table-hover table-striped">
<tr>
<th class="text-center">{% translate "Timestamp" %}</th>
<th class="text-center">{% translate "Title" %}</th>
<th class="text-center">{% translate "Action" %}</th>
</tr>
{% for notif in notifications %}
<tr class="{{ notif.level }}">
<td class="text-center">{{ notif.timestamp }}</td>
<td class="text-center">{{ notif.title }}</td>
<td class="text-center">
<a href="{% url 'notifications:view' notif.id %}" class="btn btn-primary" title="View">
<span class="glyphicon glyphicon-eye-open"></span>
</a>
<a href="{% url 'notifications:remove' notif.id %}" class="btn btn-danger" title="Remove">
<span class="glyphicon glyphicon-remove"></span>
</a>
</td>
</tr>
{% endfor %}
</table>
</div>
{% else %}
<div class="alert alert-default text-center">{% translate "No notifications." %}</div>
{% endif %}

View File

@@ -5,25 +5,22 @@
{% block page_title %}{% translate "View Notification" %}{% endblock page_title %}
{% block content %}
<h1 class="page-header text-center">
{% translate "View Notification" %}
<div class="text-right">
<a href="{% url 'notifications:list' %}" class="btn btn-primary btn-lg">
<span class="glyphicon glyphicon-arrow-left"></span>
</a>
</div>
</h1>
<div class="col-lg-12">
<h1 class="page-header text-center">
{% translate "View Notification" %}
<div class="text-right">
<a href="{% url 'notifications:list' %}" class="btn btn-primary btn-lg">
<span class="glyphicon glyphicon-arrow-left"></span>
</a>
</div>
</h1>
<div class="col-lg-12 container">
<div class="row">
<div class="col-lg-12">
<div class="panel panel-{{ notif.level }}">
<div class="panel-heading">{{ notif.timestamp }} {{ notif.title }}</div>
<div class="panel-body"><pre>{{ notif.message }}</pre></div>
</div>
</div>
<div class="row">
<div class="col-lg-12">
<div class="panel panel-{{ notif.level }}">
<div class="panel-heading">{{ notif.timestamp }} {{ notif.title }}</div>
<div class="panel-body"><pre>{{ notif.message }}</pre></div>
</div>
</div>
</div>
{% endblock %}

View File

@@ -0,0 +1,85 @@
from django.test import TestCase
from allianceauth.tests.auth_utils import AuthUtils
from ..core import NotifyApiWrapper
from ..models import Notification
class TestUserNotificationCount(TestCase):
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.user = AuthUtils.create_user("bruce_wayne")
def test_should_add_danger_notification(self):
# given
notify = NotifyApiWrapper()
# when
notify.danger(user=self.user, title="title", message="message")
# then
obj = Notification.objects.first()
self.assertEqual(obj.user, self.user)
self.assertEqual(obj.title, "title")
self.assertEqual(obj.message, "message")
self.assertEqual(obj.level, Notification.Level.DANGER)
def test_should_add_info_notification(self):
# given
notify = NotifyApiWrapper()
# when
notify.info(user=self.user, title="title", message="message")
# then
obj = Notification.objects.first()
self.assertEqual(obj.user, self.user)
self.assertEqual(obj.title, "title")
self.assertEqual(obj.message, "message")
self.assertEqual(obj.level, Notification.Level.INFO)
def test_should_add_success_notification(self):
# given
notify = NotifyApiWrapper()
# when
notify.success(user=self.user, title="title", message="message")
# then
obj = Notification.objects.first()
self.assertEqual(obj.user, self.user)
self.assertEqual(obj.title, "title")
self.assertEqual(obj.message, "message")
self.assertEqual(obj.level, Notification.Level.SUCCESS)
def test_should_add_warning_notification(self):
# given
notify = NotifyApiWrapper()
# when
notify.warning(user=self.user, title="title", message="message")
# then
obj = Notification.objects.first()
self.assertEqual(obj.user, self.user)
self.assertEqual(obj.title, "title")
self.assertEqual(obj.message, "message")
self.assertEqual(obj.level, Notification.Level.WARNING)
def test_should_add_info_notification_via_callable(self):
# given
notify = NotifyApiWrapper()
# when
notify(user=self.user, title="title", message="message")
# then
obj = Notification.objects.first()
self.assertEqual(obj.user, self.user)
self.assertEqual(obj.title, "title")
self.assertEqual(obj.message, "message")
self.assertEqual(obj.level, Notification.Level.INFO)
def test_should_add_danger_notification_via_callable(self):
# given
notify = NotifyApiWrapper()
# when
notify(user=self.user, title="title", message="message", level="danger")
# then
obj = Notification.objects.first()
self.assertEqual(obj.user, self.user)
self.assertEqual(obj.title, "title")
self.assertEqual(obj.message, "message")
self.assertEqual(obj.level, Notification.Level.DANGER)

View File

@@ -4,11 +4,8 @@ from allianceauth.tests.auth_utils import AuthUtils
from .. import notify
from ..models import Notification
MODULE_PATH = 'allianceauth.notifications'
class TestUserNotificationCount(TestCase):
@classmethod
def setUpTestData(cls):
cls.user = AuthUtils.create_user('magic_mike')
@@ -23,6 +20,18 @@ class TestUserNotificationCount(TestCase):
alliance_name='RIDERS'
)
def test_can_notify(self):
notify(self.user, 'dummy')
def test_can_notify_short(self):
# when
notify(self.user, "dummy")
# then
self.assertEqual(Notification.objects.filter(user=self.user).count(), 1)
def test_can_notify_full(self):
# when
notify(user=self.user, title="title", message="message", level="danger")
# then
obj = Notification.objects.first()
self.assertEqual(obj.user, self.user)
self.assertEqual(obj.title, "title")
self.assertEqual(obj.message, "message")
self.assertEqual(obj.level, "danger")

View File

@@ -73,6 +73,8 @@
],
bootstrap: true
},
"stateSave": true,
"stateDuration": 0,
drawCallback: function ( settings ) {
let api = this.api();
let rows = api.rows( {page:'current'} ).nodes();

View File

@@ -106,8 +106,10 @@
idx: 1
}
],
bootstrap: true
bootstrap: true,
},
"stateSave": true,
"stateDuration": 0,
drawCallback: function ( settings ) {
let api = this.api();
let rows = api.rows( {page:'current'} ).nodes();

View File

@@ -3,11 +3,11 @@ from django.contrib import admin
from allianceauth import hooks
from allianceauth.authentication.admin import (
MainAllianceFilter,
MainCorporationsFilter,
user_main_organization,
user_profile_pic,
user_username,
user_main_organization,
MainCorporationsFilter,
MainAllianceFilter
)
from .models import NameFormatConfig
@@ -36,19 +36,18 @@ class ServicesUserAdmin(admin.ModelAdmin):
MainAllianceFilter,
'user__date_joined',
)
list_select_related = (
'user', 'user__profile__main_character', 'user__profile__state'
)
@admin.display(ordering='user__profile__state__name')
def _state(self, obj):
return obj.user.profile.state.name
_state.short_description = 'state'
_state.admin_order_field = 'user__profile__state__name'
@admin.display(ordering='user__date_joined')
def _date_joined(self, obj):
return obj.user.date_joined
_date_joined.short_description = 'date joined'
_date_joined.admin_order_field = 'user__date_joined'
class NameFormatConfigForm(forms.ModelForm):
def __init__(self, *args, **kwargs):
@@ -62,6 +61,7 @@ class NameFormatConfigForm(forms.ModelForm):
self.fields['service_name'] = forms.ChoiceField(choices=SERVICE_CHOICES)
@admin.register(NameFormatConfig)
class NameFormatConfigAdmin(admin.ModelAdmin):
form = NameFormatConfigForm
list_display = ('service_name', 'get_state_display_string')
@@ -69,6 +69,3 @@ class NameFormatConfigAdmin(admin.ModelAdmin):
def get_state_display_string(self, obj):
return ', '.join([state.name for state in obj.states.all()])
get_state_display_string.short_description = 'States'
admin.site.register(NameFormatConfig, NameFormatConfigAdmin)

View File

@@ -2,12 +2,11 @@ import logging
from django.contrib import admin
from . import __title__
from ...admin import ServicesUserAdmin
from . import __title__
from .models import DiscordUser
from .utils import LoggerAddTag
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
@@ -18,21 +17,16 @@ class DiscordUserAdmin(ServicesUserAdmin):
list_filter = ServicesUserAdmin.list_filter + ('activated',)
ordering = ('-activated',)
def _uid(self, obj):
return obj.uid
_uid.short_description = 'Discord ID (UID)'
_uid.admin_order_field = 'uid'
def _username(self, obj):
if obj.username and obj.discriminator:
return f'{obj.username}#{obj.discriminator}'
else:
return ''
def delete_queryset(self, request, queryset):
for user in queryset:
user.delete_user()
_username.short_description = 'Discord Username'
_username.admin_order_field = 'username'
@admin.display(description='Discord ID (UID)', ordering='uid')
def _uid(self, obj):
return obj.uid
@admin.display(description='Discord Username', ordering='username')
def _username(self, obj):
if obj.username and obj.discriminator:
return f'{obj.username}#{obj.discriminator}'
return ''

View File

@@ -0,0 +1,37 @@
"""Public interface for community apps who want to interact with the Discord server
of the current Alliance Auth instance.
Example
=======
Here is an example for using the api to fetch the current roles from the configured Discord server.
.. code-block:: python
from allianceauth.services.modules.discord.api import create_bot_client, discord_guild_id
client = create_bot_client() # create a new Discord client
guild_id = discord_guild_id() # get the ID of the configured Discord server
roles = client.guild_roles(guild_id) # fetch the roles from our Discord server
.. seealso::
The docs for the client class can be found here: :py:class:`~allianceauth.services.modules.discord.discord_client.client.DiscordClient`
"""
from typing import Optional
from .app_settings import DISCORD_GUILD_ID
from .core import create_bot_client, group_to_role, server_name # noqa
from .discord_client.models import Role # noqa
from .models import DiscordUser # noqa
__all__ = ["create_bot_client", "group_to_role", "server_name", "DiscordUser", "Role"]
def discord_guild_id() -> Optional[int]:
"""Guild ID of configured Discord server.
Returns:
Guild ID or ``None`` if not configured
"""
return int(DISCORD_GUILD_ID) if DISCORD_GUILD_ID else None

View File

@@ -2,16 +2,25 @@ from .utils import clean_setting
DISCORD_APP_ID = clean_setting('DISCORD_APP_ID', '')
"""App ID for the AA bot on Discord. Needs to be set."""
DISCORD_APP_SECRET = clean_setting('DISCORD_APP_SECRET', '')
"""App secret for the AA bot on Discord. Needs to be set."""
DISCORD_BOT_TOKEN = clean_setting('DISCORD_BOT_TOKEN', '')
"""Token used by the AA bot on Discord. Needs to be set."""
DISCORD_CALLBACK_URL = clean_setting('DISCORD_CALLBACK_URL', '')
"""Callback URL for OAuth with Discord. Needs to be set."""
DISCORD_GUILD_ID = clean_setting('DISCORD_GUILD_ID', '')
"""ID of the Discord Server. Needs to be set."""
# max retries of tasks after an error occurred
DISCORD_TASKS_MAX_RETRIES = clean_setting('DISCORD_TASKS_MAX_RETRIES', 3)
"""Max retries of tasks after an error occurred."""
# Pause in seconds until next retry for tasks after the API returned an error
DISCORD_TASKS_RETRY_PAUSE = clean_setting('DISCORD_TASKS_RETRY_PAUSE', 60)
"""Pause in seconds until next retry for tasks after the API returned an error."""
# automatically sync Discord users names to user's main character name when created
DISCORD_SYNC_NAMES = clean_setting('DISCORD_SYNC_NAMES', False)
"""Automatically sync Discord users names to user's main character name when created."""

View File

@@ -6,6 +6,7 @@ from django.template.loader import render_to_string
from allianceauth import hooks
from allianceauth.services.hooks import ServicesHook
from .core import server_name, user_formatted_nick
from .models import DiscordUser
from .urls import urlpatterns
from .utils import LoggerAddTag
@@ -53,7 +54,7 @@ class DiscordService(ServicesHook):
return render_to_string(
self.service_ctrl_template,
{
'server_name': DiscordUser.objects.server_name(),
'server_name': server_name(),
'user_has_account': user_has_account,
'discord_username': discord_username
},
@@ -73,7 +74,7 @@ class DiscordService(ServicesHook):
'user_pk': user.pk,
# since the new nickname is not yet in the DB we need to
# provide it manually to the task
'nickname': DiscordUser.objects.user_formatted_nick(user)
'nickname': user_formatted_nick(user)
},
priority=SINGLE_TASK_PRIORITY
)

View File

@@ -0,0 +1,129 @@
"""Core functionality of the Discord service not directly related to models."""
import logging
from typing import List, Optional, Tuple
from requests.exceptions import HTTPError
from django.contrib.auth.models import Group, User
from allianceauth.groupmanagement.models import ReservedGroupName
from allianceauth.services.hooks import NameFormatter
from . import __title__
from .app_settings import DISCORD_BOT_TOKEN, DISCORD_GUILD_ID
from .discord_client import DiscordClient, RolesSet, Role
from .discord_client.exceptions import DiscordClientException
from .utils import LoggerAddTag
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
def create_bot_client(is_rate_limited: bool = True) -> DiscordClient:
"""Create new bot client for accessing the configured Discord server.
Args:
is_rate_limited: Set to False to turn off rate limiting (use with care).
Return:
Discord client instance
"""
return DiscordClient(DISCORD_BOT_TOKEN, is_rate_limited=is_rate_limited)
def calculate_roles_for_user(
user: User,
client: DiscordClient,
discord_uid: int,
state_name: str = None,
) -> Tuple[RolesSet, Optional[bool]]:
"""Calculate current Discord roles for an Auth user.
Takes into account reserved groups and existing managed roles (e.g. nitro).
Returns:
- Discord roles, changed flag:
- True when roles have changed,
- False when they have not changed,
- None if user is not a member of the guild
"""
roles_calculated = client.match_or_create_roles_from_names_2(
guild_id=DISCORD_GUILD_ID,
role_names=_user_group_names(user=user, state_name=state_name),
)
logger.debug("Calculated roles for user %s: %s", user, roles_calculated.ids())
roles_current = client.guild_member_roles(
guild_id=DISCORD_GUILD_ID, user_id=discord_uid
)
if roles_current is None:
logger.debug("User %s is not a member of the guild.", user)
return roles_calculated, None
logger.debug("Current roles user %s: %s", user, roles_current.ids())
reserved_role_names = ReservedGroupName.objects.values_list("name", flat=True)
roles_reserved = roles_current.subset(role_names=reserved_role_names)
roles_managed = roles_current.subset(managed_only=True)
roles_persistent = roles_managed.union(roles_reserved)
if roles_calculated == roles_current.difference(roles_persistent):
return roles_calculated, False
return roles_calculated.union(roles_persistent), True
def _user_group_names(user: User, state_name: str = None) -> List[str]:
"""Names of groups and state the given user is a member of."""
if not state_name:
state_name = user.profile.state.name
group_names = [group.name for group in user.groups.all()] + [state_name]
logger.debug("Group names for roles updates of user %s are: %s", user, group_names)
return group_names
def user_formatted_nick(user: User) -> Optional[str]:
"""Name of the given user's main character with name formatting applied.
Returns:
Name or ``None`` if user has no main.
"""
from .auth_hooks import DiscordService
if user.profile.main_character:
return NameFormatter(DiscordService(), user).format_name()
return None
def group_to_role(group: Group) -> Optional[Role]:
"""Fetch the Discord role matching the given Django group by name.
Returns:
Discord role or None if no matching role exist
"""
return default_bot_client.match_role_from_name(
guild_id=DISCORD_GUILD_ID, role_name=group.name
)
def server_name(use_cache: bool = True) -> str:
"""Fetches the name of the current Discord server.
Args:
use_cache: When set False will force an API call to get the server name
Returns:
Server name or an empty string if the name could not be retrieved
"""
try:
server_name = default_bot_client.guild_name(
guild_id=DISCORD_GUILD_ID, use_cache=use_cache
)
except (HTTPError, DiscordClientException):
server_name = ""
except Exception:
logger.warning(
"Unexpected error when trying to retrieve the server name from Discord",
exc_info=True,
)
server_name = ""
return server_name
# Default bot client to be used by modules of this package
default_bot_client = create_bot_client()

View File

@@ -1,3 +1,10 @@
from .client import DiscordClient # noqa
from .exceptions import DiscordApiBackoff # noqa
from .helpers import DiscordRoles # noqa
from .app_settings import DISCORD_OAUTH_BASE_URL, DISCORD_OAUTH_TOKEN_URL # noqa
from .client import DiscordClient # noqa
from .exceptions import ( # noqa
DiscordApiBackoff,
DiscordClientException,
DiscordRateLimitExhausted,
DiscordTooManyRequestsError,
)
from .helpers import RolesSet # noqa
from .models import Guild, GuildMember, Role, User # noqa

View File

@@ -1,45 +1,56 @@
"""Settings for the Discord client.
To overwrite a default set the variable in your local Django settings, e.g:
.. code:: python
DISCORD_GUILD_NAME_CACHE_MAX_AGE = 7200
"""
from ..utils import clean_setting
# Base URL for all API calls. Must end with /.
DISCORD_API_BASE_URL = clean_setting(
'DISCORD_API_BASE_URL', 'https://discord.com/api/'
)
"""Base URL for all API calls. Must end with /."""
# Low level connecttimeout for requests to the Discord API in seconds
DISCORD_API_TIMEOUT_CONNECT = clean_setting(
'DISCORD_API_TIMEOUT', 5
)
"""Low level connect timeout for requests to the Discord API in seconds."""
# Low level read timeout for requests to the Discord API in seconds
DISCORD_API_TIMEOUT_READ = clean_setting(
'DISCORD_API_TIMEOUT', 30
)
"""Low level read timeout for requests to the Discord API in seconds."""
# Base authorization URL for Discord Oauth
DISCORD_OAUTH_BASE_URL = clean_setting(
'DISCORD_OAUTH_BASE_URL', 'https://discord.com/api/oauth2/authorize'
)
"""Base authorization URL for Discord Oauth."""
# Base authorization URL for Discord Oauth
DISCORD_OAUTH_TOKEN_URL = clean_setting(
'DISCORD_OAUTH_TOKEN_URL', 'https://discord.com/api/oauth2/token'
)
"""Base authorization URL for Discord Oauth."""
# How long the Discord guild names retrieved from the server are
# caches locally in seconds.
DISCORD_GUILD_NAME_CACHE_MAX_AGE = clean_setting(
'DISCORD_GUILD_NAME_CACHE_MAX_AGE', 3600 * 24
)
"""How long the Discord guild names retrieved from the server
are caches locally in seconds.
"""
# How long Discord roles retrieved from the server are caches locally in seconds.
DISCORD_ROLES_CACHE_MAX_AGE = clean_setting(
'DISCORD_ROLES_CACHE_MAX_AGE', 3600 * 1
)
"""How long Discord roles retrieved from the server are caches locally in seconds."""
# Turns off creation of new roles. In case the rate limit for creating roles is
# exhausted, this setting allows the Discord service to continue to function
# and wait out the reset. Rate limit is about 250 per 48 hrs.
DISCORD_DISABLE_ROLE_CREATION = clean_setting(
'DISCORD_DISABLE_ROLE_CREATION', False
)
"""Turns off creation of new roles. In case the rate limit for creating roles is
exhausted, this setting allows the Discord service to continue to function
and wait out the reset. Rate limit is about 250 per 48 hrs.
"""

View File

@@ -1,32 +1,37 @@
from hashlib import md5
"""Client for interacting with the Discord API."""
import json
import logging
from enum import IntEnum
from hashlib import md5
from http import HTTPStatus
from time import sleep
from typing import Iterable, List, Optional, Set, Tuple
from urllib.parse import urljoin
from uuid import uuid1
from redis import Redis
import requests
from requests.exceptions import HTTPError
from redis import Redis
from django.core.cache import caches
from allianceauth.utils.cache import get_redis_client
from allianceauth import __title__ as AUTH_TITLE, __url__, __version__
from allianceauth import __title__ as AUTH_TITLE
from allianceauth import __url__, __version__
from .. import __title__
from ..utils import LoggerAddTag
from .app_settings import (
DISCORD_API_BASE_URL,
DISCORD_API_TIMEOUT_CONNECT,
DISCORD_API_TIMEOUT_READ,
DISCORD_DISABLE_ROLE_CREATION,
DISCORD_GUILD_NAME_CACHE_MAX_AGE,
DISCORD_OAUTH_BASE_URL,
DISCORD_OAUTH_TOKEN_URL,
DISCORD_ROLES_CACHE_MAX_AGE,
)
from .exceptions import DiscordRateLimitExhausted, DiscordTooManyRequestsError
from .helpers import DiscordRoles
from ..utils import LoggerAddTag
from .helpers import RolesSet
from .models import Guild, GuildMember, Role, User
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
@@ -58,8 +63,13 @@ MINIMUM_BLOCKING_WAIT = 50
RATE_LIMIT_RETRIES = 1000
class DiscordApiStatusCode(IntEnum):
"""Status code returned from the Discord API."""
UNKNOWN_MEMBER = 10007 #:
class DiscordClient:
"""This class provides a web client for interacting with the Discord API
"""This class provides a web client for interacting with the Discord API.
The client has rate limiting that supports concurrency.
This means it is able to ensure the API rate limit is not violated,
@@ -67,24 +77,30 @@ class DiscordClient:
In addition the client support proper API backoff.
Synchronization of rate limit infos accross multiple processes
Synchronization of rate limit infos across multiple processes
is implemented with Redis and thus requires Redis as Django cache backend.
All durations are in milliseconds.
"""
OAUTH_BASE_URL = DISCORD_OAUTH_BASE_URL
OAUTH_TOKEN_URL = DISCORD_OAUTH_TOKEN_URL
The cache is shared across all clients and processes (also using Redis).
All durations are in milliseconds.
Most errors from the API will raise a requests.HTTPError.
Args:
access_token: Discord access token used to authenticate all calls to the API
redis: Redis instance to be used.
is_rate_limited: Set to False to turn off rate limiting (use with care).
If not specified will try to use the Redis instance
from the default Django cache backend.
Raises:
ValueError: No access token provided
"""
_KEY_GLOBAL_BACKOFF_UNTIL = 'DISCORD_GLOBAL_BACKOFF_UNTIL'
_KEY_GLOBAL_RATE_LIMIT_REMAINING = 'DISCORD_GLOBAL_RATE_LIMIT_REMAINING'
_KEYPREFIX_GUILD_NAME = 'DISCORD_GUILD_NAME'
_KEYPREFIX_GUILD_ROLES = 'DISCORD_GUILD_ROLES'
_KEYPREFIX_ROLE_NAME = 'DISCORD_ROLE_NAME'
_NICK_MAX_CHARS = 32
_HTTP_STATUS_CODE_NOT_FOUND = 404
_HTTP_STATUS_CODE_RATE_LIMITED = 429
_DISCORD_STATUS_CODE_UNKNOWN_MEMBER = 10007
def __init__(
self,
@@ -92,19 +108,12 @@ class DiscordClient:
redis: Redis = None,
is_rate_limited: bool = True
) -> None:
"""
Params:
- access_token: Discord access token used to authenticate all calls to the API
- redis: Redis instance to be used.
- is_rate_limited: Set to False to run of rate limiting (use with care)
If not specified will try to use the Redis instance
from the default Django cache backend.
"""
if not access_token:
raise ValueError('You must provide an access token.')
self._access_token = str(access_token)
self._is_rate_limited = bool(is_rate_limited)
if not redis:
default_cache = caches['default']
self._redis = default_cache.get_master_client()
self._redis = get_redis_client()
if not isinstance(self._redis, Redis):
raise RuntimeError(
'This class requires a Redis client, but none was provided '
@@ -132,19 +141,20 @@ class DiscordClient:
self.__redis_script_set_longer = self._redis.register_script(lua_2)
@property
def access_token(self):
def access_token(self) -> str:
"""Discord access token."""
return self._access_token
@property
def is_rate_limited(self):
def is_rate_limited(self) -> bool:
"""Wether this instance is rate limited."""
return self._is_rate_limited
def __repr__(self):
return f'{type(self).__name__}(access_token=...{self.access_token[-5:]})'
def _redis_decr_or_set(self, name: str, value: str, px: int) -> bool:
"""decreases the key value if it exists and returns the result
else sets the key
"""Decrease the key value if it exists and returns the result else set the key.
Implemented as Lua script to ensure atomicity.
"""
@@ -153,7 +163,7 @@ class DiscordClient:
)
def _redis_set_if_longer(self, name: str, value: str, px: int) -> bool:
"""like set, but only goes through if either key doesn't exist
"""Like set, but only goes through if either key doesn't exist
or px would be extended.
Implemented as Lua script to ensure atomicity.
@@ -164,111 +174,134 @@ class DiscordClient:
# users
def current_user(self) -> dict:
"""returns the user belonging to the current access_token"""
def current_user(self) -> User:
"""Fetch user belonging to the current access_token."""
authorization = f'Bearer {self.access_token}'
r = self._api_request(
method='get', route='users/@me', authorization=authorization
)
return r.json()
return User.from_dict(r.json())
# guild
def guild_infos(self, guild_id: int) -> dict:
"""Returns all basic infos about this guild"""
def guild_infos(self, guild_id: int) -> Guild:
"""Fetch all basic infos about this guild.
Args:
guild_id: Discord ID of the guild
"""
route = f"guilds/{guild_id}"
r = self._api_request(method='get', route=route)
return r.json()
return Guild.from_dict(r.json())
def guild_name(self, guild_id: int, use_cache: bool = True) -> str:
"""returns the name of this guild (cached)
or an empty string if something went wrong
"""Fetch the name of this guild (cached).
Params:
- guild_id: ID of current guild
- use_cache: When set to False will force an API call to get the server name
Args:
guild_id: Discord ID of the guild
use_cache: When set to False will force an API call to get the server name
Returns:
Name of the server or an empty string if something went wrong.
"""
key_name = self._guild_name_cache_key(guild_id)
if use_cache:
guild_name = self._redis_decode(self._redis.get(key_name))
else:
guild_name = None
guild_name = ""
if not guild_name:
guild_infos = self.guild_infos(guild_id)
if 'name' in guild_infos:
guild_name = guild_infos['name']
self._redis.set(
name=key_name,
value=guild_name,
ex=DISCORD_GUILD_NAME_CACHE_MAX_AGE
)
try:
guild = self.guild_infos(guild_id)
except HTTPError:
guild_name = ""
else:
guild_name = ''
guild_name = guild.name
self._redis.set(
name=key_name, value=guild_name, ex=DISCORD_GUILD_NAME_CACHE_MAX_AGE
)
return guild_name
@classmethod
def _guild_name_cache_key(cls, guild_id: int) -> str:
"""Returns key for accessing role given by name in the role cache"""
"""Construct key for accessing role given by name in the role cache.
Args:
guild_id: Discord ID of the guild
"""
gen_key = DiscordClient._generate_hash(f'{guild_id}')
return f'{cls._KEYPREFIX_GUILD_NAME}__{gen_key}'
# guild roles
def guild_roles(self, guild_id: int, use_cache: bool = True) -> list:
"""Returns the list of all roles for this guild
def guild_roles(self, guild_id: int, use_cache: bool = True) -> Set[Role]:
"""Fetch all roles for this guild.
If use_cache is set to False it will always hit the API to retrieve
fresh data and update the cache
Args:
guild_id: Discord ID of the guild
use_cache: If is set to False it will always hit the API to retrieve
fresh data and update the cache.
Returns:
"""
cache_key = self._guild_roles_cache_key(guild_id)
roles = None
if use_cache:
roles_raw = self._redis.get(name=cache_key)
if roles_raw:
logger.debug('Returning roles for guild %s from cache', guild_id)
return json.loads(self._redis_decode(roles_raw))
else:
logger.debug('No roles for guild %s in cache', guild_id)
route = f"guilds/{guild_id}/roles"
r = self._api_request(method='get', route=route)
roles = r.json()
if roles and isinstance(roles, list):
roles = json.loads(self._redis_decode(roles_raw))
logger.debug('No roles for guild %s in cache', guild_id)
if roles is None:
route = f"guilds/{guild_id}/roles"
r = self._api_request(method='get', route=route)
roles = r.json()
if not roles or not isinstance(roles, list):
raise RuntimeError(
f"Unexpected response when fetching roles from API: {roles}"
)
self._redis.set(
name=cache_key,
value=json.dumps(roles),
ex=DISCORD_ROLES_CACHE_MAX_AGE
)
return roles
return {Role.from_dict(role) for role in roles}
def create_guild_role(self, guild_id: int, role_name: str, **kwargs) -> dict:
def create_guild_role(
self, guild_id: int, role_name: str, **kwargs
) -> Optional[Role]:
"""Create a new guild role with the given name.
See official documentation for additional optional parameters.
Note that Discord allows the creation of multiple roles with the same name,
so to avoid duplicates it's important to check existing roles
before creating new one
returns a new role dict on success
Args:
guild_id: Discord ID of the guild
role_name: Name of new role to create
Returns:
new role on success
"""
route = f"guilds/{guild_id}/roles"
data = {'name': DiscordRoles.sanitize_role_name(role_name)}
data = {'name': Role.sanitize_name(role_name)}
data.update(kwargs)
r = self._api_request(method='post', route=route, data=data)
role = r.json()
if role:
self._invalidate_guild_roles_cache(guild_id)
return role
return Role.from_dict(role)
return None
def delete_guild_role(self, guild_id: int, role_id: int) -> bool:
"""Deletes a guild role"""
"""Delete a guild role."""
route = f"guilds/{guild_id}/roles/{role_id}"
r = self._api_request(method='delete', route=route)
if r.status_code == 204:
self._invalidate_guild_roles_cache(guild_id)
return True
else:
return False
return False
def _invalidate_guild_roles_cache(self, guild_id: int) -> None:
cache_key = self._guild_roles_cache_key(guild_id)
@@ -277,67 +310,79 @@ class DiscordClient:
@classmethod
def _guild_roles_cache_key(cls, guild_id: int) -> str:
"""Returns key for accessing cached roles for a guild"""
"""Construct key for accessing cached roles for a guild.
Args:
guild_id: Discord ID of the guild
"""
gen_key = cls._generate_hash(f'{guild_id}')
return f'{cls._KEYPREFIX_GUILD_ROLES}__{gen_key}'
def match_role_from_name(self, guild_id: int, role_name: str) -> dict:
"""returns Discord role matching the given name or an empty dict"""
guild_roles = DiscordRoles(self.guild_roles(guild_id))
def match_role_from_name(self, guild_id: int, role_name: str) -> Optional[Role]:
"""Fetch Discord role matching the given name (cached).
Args:
guild_id: Discord ID of the guild
role_name: Name of role
Returns:
Matching role or None if no match is found
"""
guild_roles = RolesSet(self.guild_roles(guild_id))
return guild_roles.role_by_name(role_name)
def match_or_create_roles_from_names(self, guild_id: int, role_names: list) -> list:
"""returns Discord roles matching the given names
Returns as list of tuple of role and created flag
def match_or_create_roles_from_names(
self, guild_id: int, role_names: Iterable[str]
) -> List[Tuple[Role, bool]]:
"""Fetch or create Discord roles matching the given names (cached).
Will try to match with existing roles names
Non-existing roles will be created, then created flag will be True
Params:
- guild_id: ID of guild
- role_names: list of name strings each defining a role
Args:
guild_id: ID of guild
role_names: list of name strings each defining a role
Returns:
List of tuple of Role and created flag
"""
roles = list()
guild_roles = DiscordRoles(self.guild_roles(guild_id))
role_names_cleaned = {
DiscordRoles.sanitize_role_name(name) for name in role_names
}
guild_roles = RolesSet(self.guild_roles(guild_id))
role_names_cleaned = {Role.sanitize_name(name) for name in role_names}
for role_name in role_names_cleaned:
role, created = self.match_or_create_role_from_name(
guild_id=guild_id,
role_name=DiscordRoles.sanitize_role_name(role_name),
guild_roles=guild_roles
guild_id=guild_id, role_name=role_name, guild_roles=guild_roles
)
if role:
roles.append((role, created))
if created:
guild_roles = guild_roles.union(DiscordRoles([role]))
guild_roles = guild_roles.union(RolesSet([role]))
return roles
def match_or_create_role_from_name(
self, guild_id: int, role_name: str, guild_roles: DiscordRoles = None
) -> tuple:
"""returns Discord role matching the given name
Returns as tuple of role and created flag
self, guild_id: int, role_name: str, guild_roles: RolesSet = None
) -> Tuple[Role, bool]:
"""Fetch or create Discord role matching the given name.
Will try to match with existing roles names
Non-existing roles will be created, then created flag will be True
Params:
- guild_id: ID of guild
- role_name: strings defining name of a role
- guild_roles: All known guild roles as DiscordRoles object.
Helps to void redundant lookups of guild roles
when this method is used multiple times.
Args:
guild_id: ID of guild
role_name: strings defining name of a role
guild_roles: All known guild roles as RolesSet object.
Helps to void redundant lookups of guild roles
when this method is used multiple times.
Returns:
Tuple of Role and created flag
"""
if not isinstance(role_name, str):
raise TypeError('role_name must be of type string')
created = False
if guild_roles is None:
guild_roles = DiscordRoles(self.guild_roles(guild_id))
guild_roles = RolesSet(self.guild_roles(guild_id))
role = guild_roles.role_by_name(role_name)
if not role:
if not DISCORD_DISABLE_ROLE_CREATION:
@@ -346,9 +391,24 @@ class DiscordClient:
created = True
else:
role = None
return role, created
def match_or_create_roles_from_names_2(
self, guild_id: int, role_names: Iterable[str]
) -> RolesSet:
"""Fetch or create Discord role matching the given name.
Wrapper for ``match_or_create_role_from_name()``
Returns:
Roles as RolesSet object.
"""
return RolesSet.create_from_matched_roles(
self.match_or_create_roles_from_names(
guild_id=guild_id, role_names=role_names
)
)
# guild members
def add_guild_member(
@@ -358,13 +418,13 @@ class DiscordClient:
access_token: str,
role_ids: list = None,
nick: str = None
) -> bool:
"""Adds a user to the guilds.
) -> Optional[bool]:
"""Adds a user to the guild.
Returns:
- True when a new user was added
- None if the user already existed
- False when something went wrong or raises exception
- True when a new user was added
- None if the user already existed
- False when something went wrong or raises exception
"""
route = f"guilds/{guild_id}/members/{user_id}"
data = {
@@ -372,42 +432,49 @@ class DiscordClient:
}
if role_ids:
data['roles'] = self._sanitize_role_ids(role_ids)
if nick:
data['nick'] = str(nick)[:self._NICK_MAX_CHARS]
data['nick'] = GuildMember.sanitize_nick(nick)
r = self._api_request(method='put', route=route, data=data)
r.raise_for_status()
if r.status_code == 201:
return True
elif r.status_code == 204:
return None
else:
return False
return False
def guild_member(self, guild_id: int, user_id: int) -> dict:
"""returns the user info for a guild member
def guild_member(self, guild_id: int, user_id: int) -> Optional[GuildMember]:
"""Fetch info for a guild member.
or None if the user is not a member of the guild
Args:
guild_id: Discord ID of the guild
user_id: Discord ID of the user
Returns:
guild member or ``None`` if the user is not a member of the guild
"""
route = f'guilds/{guild_id}/members/{user_id}'
r = self._api_request(method='get', route=route, raise_for_status=False)
if self._is_member_unknown_error(r):
logger.warning("Discord user ID %s could not be found on server.", user_id)
return None
else:
r.raise_for_status()
return r.json()
r.raise_for_status()
return GuildMember.from_dict(r.json())
def modify_guild_member(
self, guild_id: int, user_id: int, role_ids: list = None, nick: str = None
) -> bool:
"""Modify attributes of a guild member.
self, guild_id: int, user_id: int, role_ids: List[int] = None, nick: str = None
) -> Optional[bool]:
"""Set properties of a guild member.
Args:
guild_id: Discord ID of the guild
user_id: Discord ID of the user
roles_id: New list of role IDs (if provided)
nick: New nickname (if provided)
Returns
- True when successful
- None if user is not a member of this guild
- False otherwise
- True when successful
- None if user is not a member of this guild
- False otherwise
"""
if not role_ids and not nick:
raise ValueError('Must specify role_ids or nick')
@@ -420,7 +487,7 @@ class DiscordClient:
data['roles'] = self._sanitize_role_ids(role_ids)
if nick:
data['nick'] = self._sanitize_nick(nick)
data['nick'] = GuildMember.sanitize_nick(nick)
route = f"guilds/{guild_id}/members/{user_id}"
r = self._api_request(
@@ -429,21 +496,22 @@ class DiscordClient:
if self._is_member_unknown_error(r):
logger.warning('User ID %s is not a member of this guild', user_id)
return None
else:
r.raise_for_status()
r.raise_for_status()
if r.status_code == 204:
return True
else:
return False
return False
def remove_guild_member(self, guild_id: int, user_id: int) -> bool:
"""Remove a member from a guild
def remove_guild_member(self, guild_id: int, user_id: int) -> Optional[bool]:
"""Remove a member from a guild.
Args:
guild_id: Discord ID of the guild
user_id: Discord ID of the user
Returns:
- True when successful
- None if member does not exist
- False otherwise
- True when successful
- None if member does not exist
- False otherwise
"""
route = f"guilds/{guild_id}/members/{user_id}"
r = self._api_request(
@@ -452,19 +520,16 @@ class DiscordClient:
if self._is_member_unknown_error(r):
logger.warning('User ID %s is not a member of this guild', user_id)
return None
else:
r.raise_for_status()
r.raise_for_status()
if r.status_code == 204:
return True
else:
return False
return False
# Guild member roles
def add_guild_member_role(
self, guild_id: int, user_id: int, role_id: int
) -> bool:
) -> Optional[bool]:
"""Adds a role to a guild member
Returns:
@@ -477,43 +542,69 @@ class DiscordClient:
if self._is_member_unknown_error(r):
logger.warning('User ID %s is not a member of this guild', user_id)
return None
else:
r.raise_for_status()
r.raise_for_status()
if r.status_code == 204:
return True
else:
return False
return False
def remove_guild_member_role(
self, guild_id: int, user_id: int, role_id: int
) -> bool:
"""Removes a role to a guild member
) -> Optional[bool]:
"""Remove a role to a guild member
Args:
guild_id: Discord ID of the guild
user_id: Discord ID of the user
role_id: Discord ID of role to be removed
Returns:
- True when successful
- None if member does not exist
- False otherwise
- True when successful
- None if member does not exist
- False otherwise
"""
route = f"guilds/{guild_id}/members/{user_id}/roles/{role_id}"
r = self._api_request(method='delete', route=route, raise_for_status=False)
if self._is_member_unknown_error(r):
logger.warning('User ID %s is not a member of this guild', user_id)
return None
else:
r.raise_for_status()
r.raise_for_status()
if r.status_code == 204:
return True
else:
return False
return False
def guild_member_roles(self, guild_id: int, user_id: int) -> Optional[RolesSet]:
"""Fetch the current guild roles of a guild member.
Args:
- guild_id: Discord guild ID
- user_id: Discord user ID
Returns:
- Member roles
- None if user is not a member of the guild
"""
member_info = self.guild_member(guild_id=guild_id, user_id=user_id)
if member_info is None:
return None # User is no longer a member
guild_roles = RolesSet(self.guild_roles(guild_id=guild_id))
logger.debug('Current guild roles: %s', guild_roles.ids())
if not guild_roles.has_roles(member_info.roles):
guild_roles = RolesSet(
self.guild_roles(guild_id=guild_id, use_cache=False)
)
if not guild_roles.has_roles(member_info.roles):
role_ids = set(member_info.roles).difference(guild_roles.ids())
raise RuntimeError(
f'Discord user {user_id} has unknown roles: {role_ids}'
)
return guild_roles.subset(member_info.roles)
@classmethod
def _is_member_unknown_error(cls, r: requests.Response) -> bool:
try:
result = (
r.status_code == cls._HTTP_STATUS_CODE_NOT_FOUND
and r.json()['code'] == cls._DISCORD_STATUS_CODE_UNKNOWN_MEMBER
r.status_code == HTTPStatus.NOT_FOUND
and r.json()['code'] == DiscordApiStatusCode.UNKNOWN_MEMBER
)
except (ValueError, KeyError):
result = False
@@ -530,7 +621,19 @@ class DiscordClient:
authorization: str = None,
raise_for_status: bool = True
) -> requests.Response:
"""Core method for performing all API calls"""
"""Core method for performing all API calls.
Args:
method: HTTP method of the request, e.g. "get"
route: Route in the Discord API, e.g. "users/@me"
data: Data to be send with the request
authorization: The authorization string to be used.
Will use the default bot token if not set.
raise_for_status: Whether a requests exception is to be raised when not ok
Returns:
The raw response from the API
"""
uid = uuid1().hex
if not hasattr(requests, method):
@@ -578,7 +681,7 @@ class DiscordClient:
r.text
)
if r.status_code == self._HTTP_STATUS_CODE_RATE_LIMITED:
if r.status_code == HTTPStatus.TOO_MANY_REQUESTS:
self._handle_new_api_backoff(r, uid)
self._report_rate_limit_from_api(r, uid)
@@ -589,9 +692,10 @@ class DiscordClient:
return r
def _handle_ongoing_api_backoff(self, uid: str) -> None:
"""checks if api is currently on backoff
if on backoff: will do a blocking wait if it expires soon,
else raises exception
"""Check if api is currently on backoff.
If on backoff: will do a blocking wait if it expires soon,
else raises exception.
"""
global_backoff_duration = self._redis.pttl(self._KEY_GLOBAL_BACKOFF_UNTIL)
if global_backoff_duration > 0:
@@ -611,8 +715,9 @@ class DiscordClient:
raise DiscordTooManyRequestsError(retry_after=global_backoff_duration)
def _ensure_rate_limed_not_exhausted(self, uid: str) -> int:
"""ensures that the rate limit is not exhausted
if exhausted: will do a blocking wait if rate limit resets soon,
"""Ensures that the rate limit is not exhausted.
If exhausted: will do a blocking wait if rate limit resets soon,
else raises exception
returns requests remaining on success
@@ -655,10 +760,10 @@ class DiscordClient:
)
raise DiscordRateLimitExhausted(resets_in)
raise RuntimeError('Failed to handle rate limit after after too tries.')
raise RuntimeError('Failed to handle rate limit after after too many tries.')
def _handle_new_api_backoff(self, r: requests.Response, uid: str) -> None:
"""raises exception for new API backoff error"""
"""Raise exception for new API backoff error."""
response = r.json()
if 'retry_after' in response:
try:
@@ -680,8 +785,8 @@ class DiscordClient:
)
raise DiscordTooManyRequestsError(retry_after=retry_after)
def _report_rate_limit_from_api(self, r, uid):
"""Tries to log the current rate limit reported from API"""
def _report_rate_limit_from_api(self, r, uid) -> None:
"""Try to log the current rate limit reported from API."""
if (
logger.getEffectiveLevel() <= logging.DEBUG
and 'x-ratelimit-limit' in r.headers
@@ -704,22 +809,17 @@ class DiscordClient:
@staticmethod
def _redis_decode(value: str) -> str:
"""Decodes a string from Redis and passes through None and Booleans"""
"""Decode a string from Redis and passes through None and Booleans."""
if value is not None and not isinstance(value, bool):
return value.decode('utf-8')
else:
return value
return value
@staticmethod
def _generate_hash(key: str) -> str:
"""Generate hash key for given string."""
return md5(key.encode('utf-8')).hexdigest()
@staticmethod
def _sanitize_role_ids(role_ids: list) -> list:
"""make sure its a list of integers"""
return [int(role_id) for role_id in list(role_ids)]
@classmethod
def _sanitize_nick(cls, nick: str) -> str:
"""shortens too long strings if necessary"""
return str(nick)[:cls._NICK_MAX_CHARS]
def _sanitize_role_ids(role_ids: Iterable[int]) -> List[int]:
"""Sanitize a list of role IDs, i.e. make sure its a list of unique integers."""
return [int(role_id) for role_id in set(role_ids)]

View File

@@ -1,23 +1,26 @@
"""Custom exceptions for the Discord Client package."""
import math
class DiscordClientException(Exception):
"""Base Exception for the Discord client"""
"""Base Exception for the Discord client."""
class DiscordApiBackoff(DiscordClientException):
"""Exception signaling we need to backoff from sending requests to the API for now
"""Exception signaling we need to backoff from sending requests to the API for now.
Args:
retry_after: time to retry after in milliseconds
"""
def __init__(self, retry_after: int):
"""
:param retry_after: int time to retry after in milliseconds
"""
super().__init__()
self.retry_after = int(retry_after)
@property
def retry_after_seconds(self):
"""Time to retry after in seconds."""
return math.ceil(self.retry_after / 1000)

View File

@@ -1,27 +1,37 @@
from copy import copy
from typing import Set, Iterable
from typing import Iterable, List, Optional, Set, Tuple
from .models import Role
class DiscordRoles:
"""Container class that helps dealing with Discord roles.
class RolesSet:
"""Container of Discord roles with added functionality.
Objects of this class are immutable and work in many ways like sets.
Ideally objects are initialized from raw API responses,
e.g. from DiscordClient.guild.roles()
"""
_ROLE_NAME_MAX_CHARS = 100
e.g. from DiscordClient.guild.roles().
def __init__(self, roles_lst: list) -> None:
"""roles_lst must be a list of dict, each defining a role"""
Args:
roles_lst: List of dicts, each defining a role
"""
def __init__(self, roles_lst: Iterable[Role]) -> None:
if not isinstance(roles_lst, (list, set, tuple)):
raise TypeError('roles_lst must be of type list, set or tuple')
self._roles = dict()
self._roles_by_name = dict()
for role in list(roles_lst):
self._assert_valid_role(role)
self._roles[int(role['id'])] = role
self._roles_by_name[self.sanitize_role_name(role['name'])] = role
if not isinstance(role, Role):
raise TypeError('Roles must be of type Role: %s' % role)
self._roles[role.id] = role
self._roles_by_name[role.name] = role
def __repr__(self) -> str:
if self._roles_by_name:
roles = '"' + '", "'.join(sorted(list(self._roles_by_name.keys()))) + '"'
else:
roles = ""
return f'{self.__class__.__name__}([{roles}])'
def __eq__(self, other):
if isinstance(other, type(self)):
@@ -41,15 +51,15 @@ class DiscordRoles:
return len(self._roles.keys())
def has_roles(self, role_ids: Set[int]) -> bool:
"""returns true if this objects contains all roles defined by given role_ids
incl. managed roles
"""True if this objects contains all roles defined by given role_ids
incl. managed roles.
"""
role_ids = {int(id) for id in role_ids}
all_role_ids = self._roles.keys()
return role_ids.issubset(all_role_ids)
def ids(self) -> Set[int]:
"""return a set of all role IDs"""
"""Set of all role IDs."""
return set(self._roles.keys())
def subset(
@@ -57,13 +67,13 @@ class DiscordRoles:
role_ids: Iterable[int] = None,
managed_only: bool = False,
role_names: Iterable[str] = None
) -> "DiscordRoles":
"""returns a new object containing the subset of roles
) -> "RolesSet":
"""Create instance containing the subset of roles
Args:
- role_ids: role ids must be in the provided list
- managed_only: roles must be managed
- role_names: role names must match provided list (not case sensitive)
role_ids: role ids must be in the provided list
managed_only: roles must be managed
role_names: role names must match provided list (not case sensitive)
"""
if role_ids is not None:
role_ids = {int(id) for id in role_ids}
@@ -75,72 +85,50 @@ class DiscordRoles:
elif role_ids is None and managed_only:
return type(self)([
role for _, role in self._roles.items() if role['managed']
role for _, role in self._roles.items() if role.managed
])
elif role_ids is not None and managed_only:
return type(self)([
role for role_id, role in self._roles.items()
if role_id in role_ids and role['managed']
if role_id in role_ids and role.managed
])
elif role_ids is None and managed_only is False and role_names is not None:
role_names = {self.sanitize_role_name(name).lower() for name in role_names}
role_names = {Role.sanitize_name(name).lower() for name in role_names}
return type(self)([
role for role in self._roles.values()
if role["name"].lower() in role_names
if role.name.lower() in role_names
])
return copy(self)
def union(self, other: object) -> "DiscordRoles":
"""returns a new roles object that is the union of this roles object
with other"""
def union(self, other: object) -> "RolesSet":
"""Create instance that is the union of this roles object with other."""
return type(self)(list(self) + list(other))
def difference(self, other: object) -> "DiscordRoles":
"""returns a new roles object that only contains the roles
that exist in the current objects, but not in other
def difference(self, other: object) -> "RolesSet":
"""Create instance that only contains the roles
that exist in the current objects, but not in other.
"""
new_ids = self.ids().difference(other.ids())
return self.subset(role_ids=new_ids)
def role_by_name(self, role_name: str) -> dict:
"""returns role if one with matching name is found else an empty dict"""
role_name = self.sanitize_role_name(role_name)
def role_by_name(self, role_name: str) -> Optional[Role]:
"""Role if one with matching name is found else None."""
role_name = Role.sanitize_name(role_name)
if role_name in self._roles_by_name:
return self._roles_by_name[role_name]
return dict()
return None
@classmethod
def create_from_matched_roles(cls, matched_roles: list) -> "DiscordRoles":
"""returns a new object created from the given list of matches roles
def create_from_matched_roles(
cls, matched_roles: List[Tuple[Role, bool]]
) -> "RolesSet":
"""Create new instance from the given list of matches roles.
matches_roles must be a list of tuples in the form: (role, created)
Args:
matches_roles: list of matches roles
"""
raw_roles = [x[0] for x in matched_roles]
return cls(raw_roles)
@staticmethod
def _assert_valid_role(role: dict) -> None:
if not isinstance(role, dict):
raise TypeError('Roles must be of type dict: %s' % role)
if 'id' not in role or 'name' not in role or 'managed' not in role:
raise ValueError('This role is not valid: %s' % role)
@classmethod
def sanitize_role_name(cls, role_name: str) -> str:
"""shortens too long strings if necessary"""
return str(role_name)[:cls._ROLE_NAME_MAX_CHARS]
def match_or_create_roles_from_names(
client: object, guild_id: int, role_names: list
) -> DiscordRoles:
"""Shortcut for getting the result of matching role names as DiscordRoles object"""
return DiscordRoles.create_from_matched_roles(
client.match_or_create_roles_from_names(
guild_id=guild_id, role_names=role_names
)
)

View File

@@ -0,0 +1,125 @@
"""Implementation of Discord objects used by this client.
Note that only those objects and properties are implemented, which are needed by AA.
Names and types are mirrored from the API whenever possible.
Discord's snowflake type (used by Discord IDs) is implemented as int.
"""
from dataclasses import asdict, dataclass
from typing import FrozenSet
@dataclass(frozen=True)
class User:
"""A user on Discord."""
id: int
username: str
discriminator: str
def __post_init__(self):
object.__setattr__(self, "id", int(self.id))
object.__setattr__(self, "username", str(self.username))
object.__setattr__(self, "discriminator", str(self.discriminator))
@classmethod
def from_dict(cls, data: dict) -> "User":
"""Create object from dictionary as received from the API."""
return cls(
id=int(data["id"]),
username=data["username"],
discriminator=data["discriminator"],
)
@dataclass(frozen=True)
class Role:
"""A role on Discord."""
_ROLE_NAME_MAX_CHARS = 100
id: int
name: str
managed: bool = False
def __post_init__(self):
object.__setattr__(self, "id", int(self.id))
object.__setattr__(self, "name", self.sanitize_name(self.name))
object.__setattr__(self, "managed", bool(self.managed))
def asdict(self) -> dict:
"""Convert object into a dictionary representation."""
return asdict(self)
@classmethod
def from_dict(cls, data: dict) -> "Role":
"""Create object from dictionary as received from the API."""
return cls(id=int(data["id"]), name=data["name"], managed=data["managed"])
@classmethod
def sanitize_name(cls, role_name: str) -> str:
"""Shorten too long names if necessary."""
return str(role_name)[: cls._ROLE_NAME_MAX_CHARS]
@dataclass(frozen=True)
class Guild:
"""A guild on Discord."""
id: int
name: str
roles: FrozenSet[Role]
def __post_init__(self):
object.__setattr__(self, "id", int(self.id))
object.__setattr__(self, "name", str(self.name))
for role in self.roles:
if not isinstance(role, Role):
raise TypeError("roles can only contain Role objects.")
object.__setattr__(self, "roles", frozenset(self.roles))
@classmethod
def from_dict(cls, data: dict) -> "Guild":
"""Create object from dictionary as received from the API."""
return cls(
id=int(data["id"]),
name=data["name"],
roles=frozenset(Role.from_dict(obj) for obj in data["roles"]),
)
@dataclass(frozen=True)
class GuildMember:
"""A member of a guild on Discord."""
_NICK_MAX_CHARS = 32
roles: FrozenSet[int]
nick: str = None
user: User = None
def __post_init__(self):
if self.nick:
object.__setattr__(self, "nick", self.sanitize_nick(self.nick))
if self.user and not isinstance(self.user, User):
raise TypeError("user must be of type User")
for role in self.roles:
if not isinstance(role, int):
raise TypeError("roles can only contain ints")
object.__setattr__(self, "roles", frozenset(self.roles))
@classmethod
def from_dict(cls, data: dict) -> "GuildMember":
"""Create object from dictionary as received from the API."""
params = {"roles": {int(obj) for obj in data["roles"]}}
if data.get("user"):
params["user"] = User.from_dict(data["user"])
if data.get("nick"):
params["nick"] = data["nick"]
return cls(**params)
@classmethod
def sanitize_nick(cls, nick: str) -> str:
"""Sanitize a nick, i.e. shorten too long strings if necessary."""
return str(nick)[: cls._NICK_MAX_CHARS]

View File

@@ -1,40 +0,0 @@
TEST_GUILD_ID = 123456789012345678
TEST_USER_ID = 198765432012345678
TEST_USER_NAME = 'Peter Parker'
TEST_USER_DISCRIMINATOR = '1234'
TEST_BOT_TOKEN = 'abcdefhijlkmnopqastzvwxyz1234567890ABCDEFGHOJKLMNOPQRSTUVWXY'
TEST_ROLE_ID = 654321012345678912
def create_role(id: int, name: str, managed=False) -> dict:
return {
'id': int(id),
'name': str(name),
'managed': bool(managed)
}
def create_matched_role(role, created=False) -> tuple:
return role, created
ROLE_ALPHA = create_role(1, 'alpha')
ROLE_BRAVO = create_role(2, 'bravo')
ROLE_CHARLIE = create_role(3, 'charlie')
ROLE_CHARLIE_2 = create_role(4, 'Charlie') # Discord roles are case sensitive
ROLE_MIKE = create_role(13, 'mike', True)
ALL_ROLES = [ROLE_ALPHA, ROLE_BRAVO, ROLE_CHARLIE, ROLE_MIKE]
def create_user_info(
id: int = TEST_USER_ID,
username: str = TEST_USER_NAME,
discriminator: str = TEST_USER_DISCRIMINATOR
):
return {
'id': str(id),
'username': str(username[:32]),
'discriminator': str(discriminator[:4])
}

View File

@@ -0,0 +1,155 @@
{
"guilds": {
"2909267986263572999": {
"id": "2909267986263572999",
"name": "Mason's Test Server",
"icon": "389030ec9db118cb5b85a732333b7c98",
"description": null,
"splash": "75610b05a0dd09ec2c3c7df9f6975ea0",
"discovery_splash": null,
"approximate_member_count": 2,
"approximate_presence_count": 2,
"features": [
"INVITE_SPLASH",
"VANITY_URL",
"COMMERCE",
"BANNER",
"NEWS",
"VERIFIED",
"VIP_REGIONS"
],
"emojis": [
{
"name": "ultrafastparrot",
"roles": [],
"id": "393564762228785161",
"require_colons": true,
"managed": false,
"animated": true,
"available": true
}
],
"banner": "5c3cb8d1bc159937fffe7e641ec96ca7",
"owner_id": "53908232506183680",
"application_id": null,
"region": null,
"afk_channel_id": null,
"afk_timeout": 300,
"system_channel_id": null,
"widget_enabled": true,
"widget_channel_id": "639513352485470208",
"verification_level": 0,
"roles": [
{
"id": "2909267986263572999",
"name": "@everyone",
"permissions": "49794752",
"position": 0,
"color": 0,
"hoist": false,
"managed": false,
"mentionable": false
}
],
"default_message_notifications": 1,
"mfa_level": 0,
"explicit_content_filter": 0,
"max_presences": null,
"max_members": 250000,
"max_video_channel_users": 25,
"vanity_url_code": "no",
"premium_tier": 0,
"premium_subscription_count": 0,
"system_channel_flags": 0,
"preferred_locale": "en-US",
"rules_channel_id": null,
"public_updates_channel_id": null
}
},
"guildMembers": {
"1": {
"user": {},
"nick": null,
"avatar": null,
"roles": [],
"joined_at": "2015-04-26T06:26:56.936000+00:00",
"deaf": false,
"mute": false
},
"2": {
"user": {
"id": "80351110224678912",
"username": "Nelly",
"discriminator": "1337",
"avatar": "8342729096ea3675442027381ff50dfe",
"verified": true,
"email": "nelly@discord.com",
"flags": 64,
"banner": "06c16474723fe537c283b8efa61a30c8",
"accent_color": 16711680,
"premium_type": 1,
"public_flags": 64
},
"nick": "Nelly the great",
"avatar": null,
"roles": [
"197150972374548480",
"41771983423143936"
],
"joined_at": "2015-04-26T06:26:56.936000+00:00",
"deaf": false,
"mute": false
}
},
"roles": {
"197150972374548480": {
"id": "197150972374548480",
"name": "My Managed Role",
"color": 3447003,
"hoist": false,
"icon": "cf3ced8600b777c9486c6d8d84fb4327",
"unicode_emoji": null,
"position": 2,
"permissions": "66321471",
"managed": true,
"mentionable": false
},
"2909267986263572999": {
"id": "2909267986263572999",
"name": "@everyone",
"permissions": "49794752",
"position": 0,
"color": 0,
"hoist": false,
"managed": false,
"mentionable": false
},
"41771983423143936": {
"id": "41771983423143936",
"name": "WE DEM BOYZZ!!!!!!",
"color": 3447003,
"hoist": true,
"icon": "cf3ced8600b777c9486c6d8d84fb4327",
"unicode_emoji": null,
"position": 1,
"permissions": "66321471",
"managed": false,
"mentionable": false
}
},
"users": {
"80351110224678912": {
"id": "80351110224678912",
"username": "Nelly",
"discriminator": "1337",
"avatar": "8342729096ea3675442027381ff50dfe",
"verified": true,
"email": "nelly@discord.com",
"flags": 64,
"banner": "06c16474723fe537c283b8efa61a30c8",
"accent_color": 16711680,
"premium_type": 1,
"public_flags": 64
}
}
}

View File

@@ -0,0 +1,110 @@
from itertools import count
from django.utils.timezone import now
from ..client import DiscordApiStatusCode
from ..models import Guild, GuildMember, Role, User
TEST_GUILD_ID = 123456789012345678
TEST_GUILD_NAME = "Test Guild"
TEST_USER_ID = 198765432012345678
TEST_USER_NAME = "Peter Parker"
TEST_USER_DISCRIMINATOR = "1234"
TEST_BOT_TOKEN = "abcdefhijlkmnopqastzvwxyz1234567890ABCDEFGHOJKLMNOPQRSTUVWXY"
TEST_ROLE_ID = 654321012345678912
def create_discord_role_object(id: int, name: str, managed: bool = False) -> dict:
return {"id": str(int(id)), "name": str(name), "managed": bool(managed)}
def create_matched_role(role, created=False) -> tuple:
return role, created
def create_discord_user_object(**kwargs):
params = {
"id": TEST_USER_ID,
"username": TEST_USER_NAME,
"discriminator": TEST_USER_DISCRIMINATOR,
}
params.update(kwargs)
params["id"] = str(int(params["id"]))
return params
def create_discord_guild_member_object(user=None, **kwargs):
user_params = {}
if user:
user_params["user"] = user
params = {
"user": create_discord_user_object(**user_params),
"roles": [],
"joined_at": now().isoformat(),
"deaf": False,
"mute": False,
}
params.update(kwargs)
params["roles"] = [str(int(obj)) for obj in params["roles"]]
return params
def create_discord_error_response(code: int) -> dict:
return {"code": int(code)}
def create_discord_error_response_unknown_member() -> dict:
return create_discord_error_response(DiscordApiStatusCode.UNKNOWN_MEMBER.value)
def create_discord_guild_object(**kwargs):
params = {"id": TEST_GUILD_ID, "name": TEST_GUILD_NAME, "roles": []}
params.update(kwargs)
params["id"] = str(int(params["id"]))
return params
def create_user(**kwargs):
params = {
"id": TEST_USER_ID,
"username": TEST_USER_NAME,
"discriminator": TEST_USER_DISCRIMINATOR,
}
params.update(kwargs)
return User(**params)
def create_guild(**kwargs):
params = {"id": TEST_GUILD_ID, "name": TEST_GUILD_NAME, "roles": []}
params.update(kwargs)
return Guild(**params)
def create_guild_member(**kwargs):
params = {"user": create_user(), "roles": []}
params.update(kwargs)
return GuildMember(**params)
def create_role(**kwargs) -> dict:
params = {"managed": False}
params.update(kwargs)
if "id" not in params:
params["id"] = next_number("role")
if "name" not in params:
params["name"] = f"Test Role #{params['id']}"
return Role(**params)
def next_number(key: str = None) -> int:
"""Calculate the next number in a persistent sequence."""
if key is None:
key = "_general"
try:
return next_number._counter[key].__next__()
except AttributeError:
next_number._counter = dict()
except KeyError:
pass
next_number._counter[key] = count(start=1)
return next_number._counter[key].__next__()

View File

@@ -1,177 +1,201 @@
from unittest import TestCase
from allianceauth.utils.testing import NoSocketsTestCase
from . import (
ROLE_ALPHA,
ROLE_BRAVO,
ROLE_CHARLIE,
ROLE_CHARLIE_2,
ROLE_MIKE,
ALL_ROLES,
create_role
)
from .. import DiscordRoles
from ..helpers import RolesSet
from .factories import create_matched_role, create_role
MODULE_PATH = 'allianceauth.services.modules.discord.discord_client.client'
class TestDiscordRoles(TestCase):
def setUp(self):
self.all_roles = DiscordRoles(ALL_ROLES)
class TestRolesSet(NoSocketsTestCase):
def test_can_create_simple(self):
roles_raw = [ROLE_ALPHA]
roles = DiscordRoles(roles_raw)
# given
roles_raw = [create_role()]
# when
roles = RolesSet(roles_raw)
# then
self.assertListEqual(list(roles), roles_raw)
def test_can_create_empty(self):
roles_raw = []
roles = DiscordRoles(roles_raw)
# when
roles = RolesSet([])
# then
self.assertListEqual(list(roles), [])
def test_raises_exception_if_roles_raw_of_wrong_type(self):
with self.assertRaises(TypeError):
DiscordRoles({'id': 1})
RolesSet({"id": 1})
def test_raises_exception_if_list_contains_non_dict(self):
roles_raw = [ROLE_ALPHA, 'not_valid']
# given
roles_raw = [create_role(), "not_valid"]
# when/then
with self.assertRaises(TypeError):
DiscordRoles(roles_raw)
def test_raises_exception_if_invalid_role_1(self):
roles_raw = [{'name': 'alpha', 'managed': False}]
with self.assertRaises(ValueError):
DiscordRoles(roles_raw)
def test_raises_exception_if_invalid_role_2(self):
roles_raw = [{'id': 1, 'managed': False}]
with self.assertRaises(ValueError):
DiscordRoles(roles_raw)
def test_raises_exception_if_invalid_role_3(self):
roles_raw = [{'id': 1, 'name': 'alpha'}]
with self.assertRaises(ValueError):
DiscordRoles(roles_raw)
RolesSet(roles_raw)
def test_roles_are_equal(self):
roles_a = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
roles_b = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
# given
role_a = create_role()
role_b = create_role()
roles_a = RolesSet([role_a, role_b])
roles_b = RolesSet([role_a, role_b])
# when/then
self.assertEqual(roles_a, roles_b)
def test_roles_are_not_equal(self):
roles_a = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
roles_b = DiscordRoles([ROLE_ALPHA])
# given
role_a = create_role()
role_b = create_role()
roles_a = RolesSet([role_a, role_b])
roles_b = RolesSet([role_a])
# when/then
self.assertNotEqual(roles_a, roles_b)
def test_different_objects_are_not_equal(self):
roles_a = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
roles_a = RolesSet([])
self.assertFalse(roles_a == "invalid")
def test_len(self):
self.assertEqual(len(self.all_roles), 4)
# given
role_a = create_role()
role_b = create_role()
roles = RolesSet([role_a, role_b])
# when/then
self.assertEqual(len(roles), 2)
def test_contains(self):
self.assertTrue(1 in self.all_roles)
self.assertFalse(99 in self.all_roles)
def test_sanitize_role_name(self):
role_name_input = 'x' * 110
role_name_expected = 'x' * 100
result = DiscordRoles.sanitize_role_name(role_name_input)
self.assertEqual(result, role_name_expected)
# given
role_a = create_role(id=1)
roles = RolesSet([role_a])
# when/then
self.assertTrue(1 in roles)
self.assertFalse(99 in roles)
def test_objects_are_hashable(self):
roles_a = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
roles_b = DiscordRoles([ROLE_BRAVO, ROLE_ALPHA])
roles_c = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO, ROLE_MIKE])
# given
role_a = create_role()
role_b = create_role()
role_c = create_role()
roles_a = RolesSet([role_a, role_b])
roles_b = RolesSet([role_b, role_a])
roles_c = RolesSet([role_a, role_b, role_c])
# when/then
self.assertIsNotNone(hash(roles_a))
self.assertEqual(hash(roles_a), hash(roles_b))
self.assertNotEqual(hash(roles_a), hash(roles_c))
def test_create_from_matched_roles(self):
role_a = create_role()
role_b = create_role()
matched_roles = [
(ROLE_ALPHA, True),
(ROLE_BRAVO, False)
create_matched_role(role_a, True),
create_matched_role(role_b, False),
]
roles = DiscordRoles.create_from_matched_roles(matched_roles)
self.assertSetEqual(roles.ids(), {1, 2})
class TestIds(TestCase):
def setUp(self):
self.all_roles = DiscordRoles(ALL_ROLES)
# when
roles = RolesSet.create_from_matched_roles(matched_roles)
# then
self.assertEqual(roles, RolesSet([role_a, role_b]))
def test_return_role_ids_default(self):
result = self.all_roles.ids()
expected = {1, 2, 3, 13}
self.assertSetEqual(result, expected)
role_a = create_role(id=1)
role_b = create_role(id=2)
roles = RolesSet([role_a, role_b])
# when/then
self.assertSetEqual(roles.ids(), {1, 2})
def test_return_role_ids_empty(self):
roles = DiscordRoles([])
# given
roles = RolesSet([])
# when/then
self.assertSetEqual(roles.ids(), set())
class TestSubset(TestCase):
def setUp(self):
self.all_roles = DiscordRoles(ALL_ROLES)
class TestRolesSetSubset(NoSocketsTestCase):
def test_ids_only(self):
role_ids = {1, 3}
roles_subset = self.all_roles.subset(role_ids)
expected = {1, 3}
self.assertSetEqual(roles_subset.ids(), expected)
# given
role_a = create_role(id=1)
role_b = create_role(id=2)
role_c = create_role(id=3)
roles_all = RolesSet([role_a, role_b, role_c])
# when
roles_subset = roles_all.subset({1, 3})
# then
self.assertEqual(roles_subset, RolesSet([role_a, role_c]))
def test_ids_as_string_work_too(self):
role_ids = {'1', '3'}
roles_subset = self.all_roles.subset(role_ids)
expected = {1, 3}
self.assertSetEqual(roles_subset.ids(), expected)
# given
role_a = create_role(id=1)
role_b = create_role(id=2)
role_c = create_role(id=3)
roles_all = RolesSet([role_a, role_b, role_c])
# when
roles_subset = roles_all.subset({"1", "3"})
# then
self.assertEqual(roles_subset, RolesSet([role_a, role_c]))
def test_managed_only(self):
roles = self.all_roles.subset(managed_only=True)
expected = {13}
self.assertSetEqual(roles.ids(), expected)
# given
role_a = create_role(id=1)
role_m = create_role(id=13, managed=True)
roles_all = RolesSet([role_a, role_m])
# when
roles_subset = roles_all.subset(managed_only=True)
# then
self.assertEqual(roles_subset, RolesSet([role_m]))
def test_ids_and_managed_only(self):
role_ids = {1, 3, 13}
roles_subset = self.all_roles.subset(role_ids, managed_only=True)
expected = {13}
self.assertSetEqual(roles_subset.ids(), expected)
# given
role_a = create_role(id=1)
role_b = create_role(id=2)
role_m = create_role(id=13, managed=True)
roles_all = RolesSet([role_a, role_b, role_m])
# when
roles_subset = roles_all.subset({1, 13}, managed_only=True)
# then
self.assertEqual(roles_subset, RolesSet([role_m]))
def test_ids_are_empty(self):
roles = self.all_roles.subset([])
expected = set()
self.assertSetEqual(roles.ids(), expected)
# given
role_a = create_role(id=1)
role_b = create_role(id=2)
roles_all = RolesSet([role_a, role_b])
roles_subset = roles_all.subset([])
# then
self.assertEqual(roles_subset, RolesSet([]))
def test_no_parameters(self):
roles = self.all_roles.subset()
expected = {1, 2, 3, 13}
self.assertSetEqual(roles.ids(), expected)
# given
role_a = create_role(id=1)
role_b = create_role(id=2)
roles_all = RolesSet([role_a, role_b])
roles_subset = roles_all.subset()
# then
self.assertEqual(roles_subset, roles_all)
def test_should_return_role_names_only(self):
# given
all_roles = DiscordRoles([
ROLE_ALPHA, ROLE_BRAVO, ROLE_CHARLIE, ROLE_MIKE, ROLE_CHARLIE_2
])
role_a = create_role(name="alpha")
role_b = create_role(name="bravo")
role_c1 = create_role(name="charlie")
role_c2 = create_role(name="Charlie")
roles_all = RolesSet([role_a, role_b, role_c1, role_c2])
# when
roles = all_roles.subset(role_names={"bravo", "charlie"})
roles_subset = roles_all.subset(role_names={"bravo", "charlie"})
# then
self.assertSetEqual(roles.ids(), {2, 3, 4})
self.assertSetEqual(roles_subset, RolesSet([role_b, role_c1, role_c2]))
class TestHasRoles(TestCase):
def setUp(self):
self.all_roles = DiscordRoles(ALL_ROLES)
class TestRolesSetHasRoles(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
role_a = create_role(id=1)
role_b = create_role(id=2)
role_c = create_role(id=3)
cls.all_roles = RolesSet([role_a, role_b, role_c])
def test_true_if_all_roles_exit(self):
self.assertTrue(self.all_roles.has_roles([1, 2]))
def test_true_if_all_roles_exit_str(self):
self.assertTrue(self.all_roles.has_roles(['1', '2']))
self.assertTrue(self.all_roles.has_roles(["1", "2"]))
def test_false_if_role_does_not_exit(self):
self.assertFalse(self.all_roles.has_roles([99]))
@@ -183,74 +207,104 @@ class TestHasRoles(TestCase):
self.assertTrue(self.all_roles.has_roles([]))
class TestGetMatchingRolesByName(TestCase):
def setUp(self):
self.all_roles = DiscordRoles(ALL_ROLES)
class TestRolesSetGetMatchingRolesByName(NoSocketsTestCase):
def test_return_role_if_matches(self):
role_name = 'alpha'
expected = ROLE_ALPHA
result = self.all_roles.role_by_name(role_name)
self.assertEqual(result, expected)
# given
role_a = create_role(name="alpha")
role_b = create_role(name="bravo")
roles = RolesSet([role_a, role_b])
# when
result = roles.role_by_name("alpha")
# then
self.assertEqual(result, role_a)
def test_return_role_if_matches_and_limit_max_length(self):
role_name = 'x' * 120
expected = create_role(77, 'x' * 100)
roles = DiscordRoles([expected])
# given
role_name = "x" * 120
role = create_role(name="x" * 100)
roles = RolesSet([role])
# when
result = roles.role_by_name(role_name)
self.assertEqual(result, expected)
# then
self.assertEqual(result, role)
def test_return_empty_if_not_matches(self):
role_name = 'lima'
expected = {}
result = self.all_roles.role_by_name(role_name)
self.assertEqual(result, expected)
# given
role_a = create_role(name="alpha")
role_b = create_role(name="bravo")
roles = RolesSet([role_a, role_b])
# when
result = roles.role_by_name("unknown")
# then
self.assertIsNone(result)
class TestUnion(TestCase):
class TestRolesSetUnion(NoSocketsTestCase):
def test_distinct_sets(self):
roles_1 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
roles_2 = DiscordRoles([ROLE_CHARLIE, ROLE_MIKE])
roles_3 = roles_1.union(roles_2)
expected = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO, ROLE_CHARLIE, ROLE_MIKE])
self.assertEqual(roles_3, expected)
# given
role_a = create_role()
role_b = create_role()
roles_1 = RolesSet([role_a])
roles_2 = RolesSet([role_b])
# when
result = roles_1.union(roles_2)
# then
self.assertEqual(result, RolesSet([role_a, role_b]))
def test_overlapping_sets(self):
roles_1 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
roles_2 = DiscordRoles([ROLE_BRAVO, ROLE_MIKE])
roles_3 = roles_1.union(roles_2)
expected = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO, ROLE_MIKE])
self.assertEqual(roles_3, expected)
# given
role_a = create_role()
role_b = create_role()
role_c = create_role()
roles_1 = RolesSet([role_a, role_b])
roles_2 = RolesSet([role_b, role_c])
# when
result = roles_1.union(roles_2)
self.assertEqual(result, RolesSet([role_a, role_b, role_c]))
def test_identical_sets(self):
roles_1 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
roles_2 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
roles_3 = roles_1.union(roles_2)
expected = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
self.assertEqual(roles_3, expected)
role_a = create_role()
role_b = create_role()
roles_1 = RolesSet([role_a, role_b])
roles_2 = RolesSet([role_a, role_b])
# when
result = roles_1.union(roles_2)
self.assertEqual(result, RolesSet([role_a, role_b]))
class TestDifference(TestCase):
class TestRolesSetDifference(NoSocketsTestCase):
def test_distinct_sets(self):
roles_1 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
roles_2 = DiscordRoles([ROLE_CHARLIE, ROLE_MIKE])
roles_3 = roles_1.difference(roles_2)
expected = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
self.assertEqual(roles_3, expected)
# given
role_a = create_role()
role_b = create_role()
role_c = create_role()
role_d = create_role()
roles_1 = RolesSet([role_a, role_b])
roles_2 = RolesSet([role_c, role_d])
# when
result = roles_1.difference(roles_2)
# then
self.assertEqual(result, RolesSet([role_a, role_b]))
def test_overlapping_sets(self):
roles_1 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
roles_2 = DiscordRoles([ROLE_BRAVO, ROLE_MIKE])
roles_3 = roles_1.difference(roles_2)
expected = DiscordRoles([ROLE_ALPHA])
self.assertEqual(roles_3, expected)
# given
role_a = create_role()
role_b = create_role()
role_c = create_role()
roles_1 = RolesSet([role_a, role_b])
roles_2 = RolesSet([role_b, role_c])
# when
result = roles_1.difference(roles_2)
# then
self.assertEqual(result, RolesSet([role_a]))
def test_identical_sets(self):
roles_1 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
roles_2 = DiscordRoles([ROLE_ALPHA, ROLE_BRAVO])
roles_3 = roles_1.difference(roles_2)
expected = DiscordRoles([])
self.assertEqual(roles_3, expected)
# given
role_a = create_role()
role_b = create_role()
roles_1 = RolesSet([role_a, role_b])
roles_2 = RolesSet([role_a, role_b])
# when
result = roles_1.difference(roles_2)
# then
self.assertEqual(result, RolesSet([]))

View File

@@ -0,0 +1,156 @@
import json
from pathlib import Path
from unittest import TestCase
from ..models import Guild, GuildMember, Role, User
from .factories import create_guild, create_guild_member, create_role, create_user
def _fetch_example_objects() -> dict:
path = Path(__file__).parent / "example_objects.json"
with path.open("r", encoding="utf-8") as fp:
return json.load(fp)
class TestUser(TestCase):
def test_should_create_new_object(self):
# when
obj = User(id="42", username=123, discriminator=456)
# then
self.assertEqual(obj.id, 42)
self.assertEqual(obj.username, "123")
self.assertTrue(obj.discriminator, "456")
def test_should_create_from_dict(self):
# given
data = example_objects["users"]["80351110224678912"]
# when
obj = User.from_dict(data)
# then
self.assertEqual(obj.id, 80351110224678912)
self.assertEqual(obj.username, "Nelly")
self.assertEqual(obj.discriminator, "1337")
class TestRole(TestCase):
def test_should_create_new_object_with_defaults(self):
# when
obj = Role(id="42", name="x" * 110)
# then
self.assertEqual(obj.id, 42)
self.assertEqual(obj.name, "x" * 100)
self.assertFalse(obj.managed)
def test_should_create_new_object(self):
# when
obj = Role(id=42, name="name", managed=1)
# then
self.assertEqual(obj.id, 42)
self.assertEqual(obj.name, "name")
self.assertTrue(obj.managed)
def test_should_create_from_dict(self):
# given
data = example_objects["roles"]["41771983423143936"]
# when
obj = Role.from_dict(data)
# then
self.assertEqual(obj.id, 41771983423143936)
self.assertEqual(obj.name, "WE DEM BOYZZ!!!!!!")
self.assertFalse(obj.managed)
def test_should_convert_to_dict(self):
# given
role = create_role(id=42, name="Special Name", managed=True)
# when/then
self.assertDictEqual(
role.asdict(), {"id": 42, "name": "Special Name", "managed": True}
)
def test_sanitize_role_name(self):
# given
role_name_input = "x" * 110
role_name_expected = "x" * 100
# when
result = Role.sanitize_name(role_name_input)
# then
self.assertEqual(result, role_name_expected)
class TestGuild(TestCase):
def test_should_create_new_object(self):
# given
role_a = create_role()
# when
obj = Guild(id="42", name=123, roles=[role_a])
# then
self.assertEqual(obj.id, 42)
self.assertEqual(obj.name, "123")
self.assertEqual(obj.roles, frozenset([role_a]))
def test_should_create_from_dict(self):
# given
data = example_objects["guilds"]["2909267986263572999"]
# when
obj = Guild.from_dict(data)
# then
self.assertEqual(obj.id, 2909267986263572999)
self.assertEqual(obj.name, "Mason's Test Server")
(first_role,) = obj.roles
self.assertEqual(first_role.id, 2909267986263572999)
def test_should_raise_error_when_role_type_is_wrong(self):
with self.assertRaises(TypeError):
create_guild(roles=[create_role(), "invalid"])
class TestGuildMember(TestCase):
def test_should_create_new_object(self):
# given
user = create_user()
# when
obj = GuildMember(user=user, nick="x" * 40, roles=[1, 2])
# then
self.assertEqual(obj.user, user)
self.assertEqual(obj.nick, "x" * 32)
self.assertEqual(obj.roles, frozenset([1, 2]))
def test_should_create_from_dict_empty(self):
# given
data = example_objects["guildMembers"]["1"]
# when
obj = GuildMember.from_dict(data)
# then
self.assertIsNone(obj.user)
self.assertSetEqual(obj.roles, set())
self.assertIsNone(obj.nick)
def test_should_create_from_dict_full(self):
# given
data = example_objects["guildMembers"]["2"]
# when
obj = GuildMember.from_dict(data)
# then
self.assertEqual(obj.user.username, "Nelly")
self.assertSetEqual(obj.roles, {197150972374548480, 41771983423143936})
self.assertEqual(obj.nick, "Nelly the great")
def test_should_raise_error_when_user_type_is_wrong(self):
with self.assertRaises(TypeError):
create_guild_member(user="invalid")
def test_should_raise_error_when_role_type_is_wrong(self):
with self.assertRaises(TypeError):
GuildMember(roles=[1, 2, "invalid"])
def test_sanitize_nick(self):
# given
nick_input = "x" * 40
nick_expected = "x" * 32
# when
result = GuildMember.sanitize_nick(nick_input)
# then
self.assertEqual(result, nick_expected)
example_objects = _fetch_example_objects()

View File

@@ -1,30 +1,33 @@
import logging
from urllib.parse import urlencode
from requests_oauthlib import OAuth2Session
from requests.exceptions import HTTPError
from requests_oauthlib import OAuth2Session
from django.contrib.auth.models import User, Group
from django.contrib.auth.models import Group, User
from django.db import models
from django.utils.timezone import now
from allianceauth.services.hooks import NameFormatter
from . import __title__
from .app_settings import (
DISCORD_APP_ID,
DISCORD_APP_SECRET,
DISCORD_BOT_TOKEN,
DISCORD_CALLBACK_URL,
DISCORD_GUILD_ID,
DISCORD_SYNC_NAMES
DISCORD_SYNC_NAMES,
)
from .core import calculate_roles_for_user, create_bot_client
from .core import group_to_role as core_group_to_role
from .core import server_name as core_server_name
from .core import user_formatted_nick
from .discord_client import (
DISCORD_OAUTH_BASE_URL,
DISCORD_OAUTH_TOKEN_URL,
DiscordApiBackoff,
DiscordClient,
)
from .discord_client import DiscordClient
from .discord_client.exceptions import DiscordClientException, DiscordApiBackoff
from .discord_client.helpers import match_or_create_roles_from_names
from .utils import LoggerAddTag
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
@@ -56,79 +59,68 @@ class DiscordUserManager(models.Manager):
Returns: True on success, else False or raises exception
"""
try:
nickname = self.user_formatted_nick(user) if DISCORD_SYNC_NAMES else None
group_names = self.user_group_names(user)
nickname = user_formatted_nick(user) if DISCORD_SYNC_NAMES else None
access_token = self._exchange_auth_code_for_token(authorization_code)
user_client = DiscordClient(access_token, is_rate_limited=is_rate_limited)
discord_user = user_client.current_user()
user_id = discord_user['id']
bot_client = self._bot_client(is_rate_limited=is_rate_limited)
if group_names:
role_ids = match_or_create_roles_from_names(
client=bot_client,
guild_id=DISCORD_GUILD_ID,
role_names=group_names
).ids()
else:
role_ids = None
created = bot_client.add_guild_member(
guild_id=DISCORD_GUILD_ID,
user_id=user_id,
access_token=access_token,
role_ids=role_ids,
nick=nickname
bot_client = create_bot_client(is_rate_limited=is_rate_limited)
roles, changed = calculate_roles_for_user(
user=user, client=bot_client, discord_uid=discord_user.id
)
if created is not False:
if created is None:
logger.debug(
"User %s with Discord ID %s is already a member. Forcing a Refresh",
if changed is None:
# Handle new member
created = bot_client.add_guild_member(
guild_id=DISCORD_GUILD_ID,
user_id=discord_user.id,
access_token=access_token,
role_ids=list(roles.ids()),
nick=nickname
)
if not created:
logger.warning(
"Failed to add user %s with Discord ID %s to Discord server",
user,
user_id,
discord_user.id,
)
# Force an update cause the discord API won't do it for us.
if role_ids:
role_ids = list(role_ids)
updated = bot_client.modify_guild_member(
guild_id=DISCORD_GUILD_ID,
user_id=user_id,
role_ids=role_ids,
nick=nickname
)
if not updated:
# Could not update the new user so fail.
logger.warning(
"Failed to add user %s with Discord ID %s to Discord server",
user,
user_id,
)
return False
self.update_or_create(
user=user,
defaults={
'uid': user_id,
'username': discord_user['username'][:32],
'discriminator': discord_user['discriminator'][:4],
'activated': now()
}
)
logger.info(
"Added user %s with Discord ID %s to Discord server", user, user_id
)
return True
return False
else:
logger.warning(
"Failed to add user %s with Discord ID %s to Discord server",
# Handle existing member
logger.debug(
"User %s with Discord ID %s is already a member. Forcing a Refresh",
user,
user_id,
discord_user.id,
)
return False
# Force an update cause the discord API won't do it for us.
updated = bot_client.modify_guild_member(
guild_id=DISCORD_GUILD_ID,
user_id=discord_user.id,
role_ids=list(roles.ids()),
nick=nickname
)
if not updated:
# Could not update the new user so fail.
logger.warning(
"Failed to add user %s with Discord ID %s to Discord server",
user,
discord_user.id,
)
return False
self.update_or_create(
user=user,
defaults={
'uid': discord_user.id,
'username': discord_user.username[:32],
'discriminator': discord_user.discriminator[:4],
'activated': now()
}
)
logger.info(
"Added user %s with Discord ID %s to Discord server",
user,
discord_user.id
)
return True
except (HTTPError, ConnectionError, DiscordApiBackoff) as ex:
logger.exception(
@@ -136,31 +128,6 @@ class DiscordUserManager(models.Manager):
)
return False
@staticmethod
def user_formatted_nick(user: User) -> str:
"""returns the name of the given users main character with name formatting
or None if user has no main
"""
from .auth_hooks import DiscordService
if user.profile.main_character:
return NameFormatter(DiscordService(), user).format_name()
else:
return None
@staticmethod
def user_group_names(user: User, state_name: str = None) -> list:
"""returns list of group names plus state the given user is a member of"""
if not state_name:
state_name = user.profile.state.name
group_names = (
[group.name for group in user.groups.all()] + [state_name]
)
logger.debug(
"Group names for roles updates of user %s are: %s", user, group_names
)
return group_names
def user_has_account(self, user: User) -> bool:
"""Returns True if the user has an Discord account, else False
@@ -178,60 +145,41 @@ class DiscordUserManager(models.Manager):
'permissions': str(cls.BOT_PERMISSIONS)
})
return f'{DiscordClient.OAUTH_BASE_URL}?{params}'
return f'{DISCORD_OAUTH_BASE_URL}?{params}'
@classmethod
def generate_oauth_redirect_url(cls) -> str:
oauth = OAuth2Session(
DISCORD_APP_ID, redirect_uri=DISCORD_CALLBACK_URL, scope=cls.SCOPES
)
url, state = oauth.authorization_url(DiscordClient.OAUTH_BASE_URL)
url, _ = oauth.authorization_url(DISCORD_OAUTH_BASE_URL)
return url
@staticmethod
def _exchange_auth_code_for_token(authorization_code: str) -> str:
oauth = OAuth2Session(DISCORD_APP_ID, redirect_uri=DISCORD_CALLBACK_URL)
token = oauth.fetch_token(
DiscordClient.OAUTH_TOKEN_URL,
DISCORD_OAUTH_TOKEN_URL,
client_secret=DISCORD_APP_SECRET,
code=authorization_code
)
logger.debug("Received token from OAuth")
return token['access_token']
@classmethod
def server_name(cls, use_cache: bool = True) -> str:
"""returns the name of the current Discord server
or an empty string if the name could not be retrieved
@staticmethod
def group_to_role(group: Group) -> dict:
"""Fetch the Discord role matching the given Django group by name.
Params:
- use_cache: When set False will force an API call to get the server name
Returns:
- Discord role as dict
- empty dict if no matching role found
"""
try:
server_name = cls._bot_client().guild_name(
guild_id=DISCORD_GUILD_ID, use_cache=use_cache
)
except (HTTPError, DiscordClientException):
server_name = ""
except Exception:
logger.warning(
"Unexpected error when trying to retrieve the server name from Discord",
exc_info=True
)
server_name = ""
return server_name
@classmethod
def group_to_role(cls, group: Group) -> dict:
"""returns the Discord role matching the given Django group by name
or an empty dict() if no matching role exist
"""
return cls._bot_client().match_role_from_name(
guild_id=DISCORD_GUILD_ID, role_name=group.name
)
role = core_group_to_role(group)
return role.asdict() if role else dict()
@staticmethod
def _bot_client(is_rate_limited: bool = True) -> DiscordClient:
"""returns a bot client for access to the Discord API"""
return DiscordClient(DISCORD_BOT_TOKEN, is_rate_limited=is_rate_limited)
def server_name(use_cache: bool = True) -> str:
"""Fetches the name of the current Discord server.
This method is kept to ensure backwards compatibility of this API.
"""
return core_server_name(use_cache)

View File

@@ -1,4 +1,5 @@
import logging
from typing import Optional
from requests.exceptions import HTTPError
@@ -6,13 +7,17 @@ from django.contrib.auth.models import User
from django.db import models
from django.utils.translation import gettext_lazy
from allianceauth.groupmanagement.models import ReservedGroupName
from allianceauth.notifications import notify
from . import __title__
from .app_settings import DISCORD_GUILD_ID
from .discord_client import DiscordApiBackoff, DiscordClient, DiscordRoles
from .discord_client.helpers import match_or_create_roles_from_names
from .core import (
create_bot_client,
default_bot_client,
calculate_roles_for_user,
user_formatted_nick
)
from .discord_client import DiscordApiBackoff
from .managers import DiscordUserManager
from .utils import LoggerAddTag
@@ -21,14 +26,13 @@ logger = LoggerAddTag(logging.getLogger(__name__), __title__)
class DiscordUser(models.Model):
USER_RELATED_NAME = 'discord'
"""The Discord user account of an Auth user."""
user = models.OneToOneField(
User,
primary_key=True,
on_delete=models.CASCADE,
related_name=USER_RELATED_NAME,
related_name='discord',
help_text='Auth user owning this Discord account'
)
uid = models.BigIntegerField(
@@ -80,24 +84,21 @@ class DiscordUser(models.Model):
- False on error or raises exception
"""
if not nickname:
nickname = DiscordUser.objects.user_formatted_nick(self.user)
if nickname:
client = DiscordUser.objects._bot_client()
success = client.modify_guild_member(
guild_id=DISCORD_GUILD_ID,
user_id=self.uid,
nick=nickname
)
if success:
logger.info('Nickname for %s has been updated', self.user)
else:
logger.warning('Failed to update nickname for %s', self.user)
return success
else:
nickname = user_formatted_nick(self.user)
if not nickname:
return False
success = default_bot_client.modify_guild_member(
guild_id=DISCORD_GUILD_ID,
user_id=self.uid,
nick=nickname
)
if success:
logger.info('Nickname for %s has been updated', self.user)
else:
logger.warning('Failed to update nickname for %s', self.user)
return success
def update_groups(self, state_name: str = None) -> bool:
def update_groups(self, state_name: str = None) -> Optional[bool]:
"""update groups for a user based on his current group memberships.
Will add or remove roles of a user as needed.
@@ -109,57 +110,18 @@ class DiscordUser(models.Model):
- None if user is no longer a member of the Discord server
- False on error or raises exception
"""
client = DiscordUser.objects._bot_client()
member_roles = self._determine_member_roles(client)
if member_roles is None:
new_roles, is_changed = calculate_roles_for_user(
user=self.user,
client=default_bot_client,
discord_uid=self.uid,
state_name=state_name
)
if is_changed is None:
logger.debug('User is not a member of this guild %s', self.user)
return None
return self._update_roles_if_needed(client, state_name, member_roles)
def _determine_member_roles(self, client: DiscordClient) -> DiscordRoles:
"""Determine the roles of the current member / user."""
member_info = client.guild_member(guild_id=DISCORD_GUILD_ID, user_id=self.uid)
if member_info is None:
return None # User is no longer a member
guild_roles = DiscordRoles(client.guild_roles(guild_id=DISCORD_GUILD_ID))
logger.debug('Current guild roles: %s', guild_roles.ids())
if 'roles' in member_info:
if not guild_roles.has_roles(member_info['roles']):
guild_roles = DiscordRoles(
client.guild_roles(guild_id=DISCORD_GUILD_ID, use_cache=False)
)
if not guild_roles.has_roles(member_info['roles']):
raise RuntimeError(
'Member {} has unknown roles: {}'.format(
self.user,
set(member_info['roles']).difference(guild_roles.ids())
)
)
return guild_roles.subset(member_info['roles'])
raise RuntimeError('member_info from %s is not valid' % self.user)
def _update_roles_if_needed(
self, client: DiscordClient, state_name: str, member_roles: DiscordRoles
) -> bool:
"""Update the roles of this member/user if needed."""
requested_roles = match_or_create_roles_from_names(
client=client,
guild_id=DISCORD_GUILD_ID,
role_names=DiscordUser.objects.user_group_names(
user=self.user, state_name=state_name
)
)
logger.debug(
'Requested roles for user %s: %s', self.user, requested_roles.ids()
)
logger.debug('Current roles user %s: %s', self.user, member_roles.ids())
reserved_role_names = ReservedGroupName.objects.values_list("name", flat=True)
member_roles_reserved = member_roles.subset(role_names=reserved_role_names)
member_roles_managed = member_roles.subset(managed_only=True)
member_roles_persistent = member_roles_managed.union(member_roles_reserved)
if requested_roles != member_roles.difference(member_roles_persistent):
if is_changed:
logger.debug('Need to update roles for user %s', self.user)
new_roles = requested_roles.union(member_roles_persistent)
success = client.modify_guild_member(
success = default_bot_client.modify_guild_member(
guild_id=DISCORD_GUILD_ID,
user_id=self.uid,
role_ids=list(new_roles.ids())
@@ -172,41 +134,32 @@ class DiscordUser(models.Model):
logger.info('No need to update roles for user %s', self.user)
return True
def update_username(self) -> bool:
def update_username(self) -> Optional[bool]:
"""Updates the username incl. the discriminator
from the Discord server and saves it
Returns:
- True on success
- None if user is no longer a member of the Discord server
- False on error or raises exception
"""
client = DiscordUser.objects._bot_client()
user_info = client.guild_member(guild_id=DISCORD_GUILD_ID, user_id=self.uid)
if user_info is None:
success = None
elif (
user_info
and 'user' in user_info
and 'username' in user_info['user']
and 'discriminator' in user_info['user']
):
self.username = user_info['user']['username']
self.discriminator = user_info['user']['discriminator']
self.save()
logger.info('Username for %s has been updated', self.user)
success = True
else:
logger.warning('Failed to update username for %s', self.user)
success = False
return success
member_info = default_bot_client.guild_member(
guild_id=DISCORD_GUILD_ID, user_id=self.uid
)
if not member_info:
logger.warning('%s: User not a guild member', self.user)
return None
self.username = member_info.user.username
self.discriminator = member_info.user.discriminator
self.save()
logger.info('%s: Username has been updated', self.user)
return True
def delete_user(
self,
notify_user: bool = False,
is_rate_limited: bool = True,
handle_api_exceptions: bool = False
) -> bool:
) -> Optional[bool]:
"""Deletes the Discount user both on the server and locally
Params:
@@ -221,7 +174,7 @@ class DiscordUser(models.Model):
"""
try:
_user = self.user
client = DiscordUser.objects._bot_client(is_rate_limited=is_rate_limited)
client = create_bot_client(is_rate_limited=is_rate_limited)
success = client.remove_guild_member(
guild_id=DISCORD_GUILD_ID, user_id=self.uid
)
@@ -241,15 +194,13 @@ class DiscordUser(models.Model):
)
logger.info('Account for user %s was deleted.', _user)
return True
else:
logger.debug('Account for user %s was already deleted.', _user)
return None
logger.debug('Account for user %s was already deleted.', _user)
return None
else:
logger.warning(
'Failed to remove user %s from the Discord server', _user
)
return False
logger.warning(
'Failed to remove user %s from the Discord server', _user
)
return False
except (HTTPError, ConnectionError, DiscordApiBackoff) as ex:
if handle_api_exceptions:
@@ -257,5 +208,4 @@ class DiscordUser(models.Model):
'Failed to remove user %s from Discord server: %s',self.user, ex
)
return False
else:
raise ex
raise ex

View File

@@ -1,19 +1,6 @@
from django.contrib.auth.models import Group
from allianceauth.tests.auth_utils import AuthUtils
from ..discord_client.tests import ( # noqa
TEST_GUILD_ID,
TEST_USER_ID,
TEST_USER_NAME,
TEST_USER_DISCRIMINATOR,
create_role,
ROLE_ALPHA,
ROLE_BRAVO,
ROLE_CHARLIE,
ROLE_CHARLIE_2,
ROLE_MIKE,
ALL_ROLES,
create_user_info
)
DEFAULT_AUTH_GROUP = 'Member'
MODULE_PATH = 'allianceauth.services.modules.discord'

View File

@@ -0,0 +1,31 @@
from django.utils.timezone import now
from allianceauth.authentication.backends import StateBackend
from allianceauth.tests.auth_utils import AuthUtils
from ..discord_client.tests.factories import (
TEST_USER_DISCRIMINATOR,
TEST_USER_ID,
TEST_USER_NAME,
)
from ..models import DiscordUser
def create_user(**kwargs):
params = {"username": TEST_USER_NAME}
params.update(kwargs)
username = StateBackend.iterate_username(params["username"])
user = AuthUtils.create_user(username)
return AuthUtils.add_permission_to_user_by_name("discord.access_discord", user)
def create_discord_user(user=None, **kwargs):
params = {
"user": user or create_user(),
"uid": TEST_USER_ID,
"username": TEST_USER_NAME,
"discriminator": TEST_USER_DISCRIMINATOR,
"activated": now(),
}
params.update(kwargs)
return DiscordUser.objects.create(**params)

View File

@@ -35,17 +35,17 @@ import logging
from uuid import uuid1
import random
from django.core.cache import caches
from django.contrib.auth.models import User, Group
from allianceauth.services.modules.discord.models import DiscordUser
from allianceauth.utils.cache import get_redis_client
logger = logging.getLogger('allianceauth')
MAX_RUNS = 3
def clear_cache():
default_cache = caches['default']
redis = default_cache.get_master_client()
redis = get_redis_client()
redis.flushall()
logger.info('Cache flushed')

View File

@@ -1,26 +1,32 @@
from django.test import TestCase, RequestFactory
from unittest.mock import patch
from django.contrib.admin.sites import AdminSite
from django.contrib.auth.models import User
from django.test import RequestFactory
from django.utils.timezone import now
from allianceauth.authentication.models import CharacterOwnership
from allianceauth.eveonline.models import (
EveCharacter, EveCorporationInfo, EveAllianceInfo
EveAllianceInfo,
EveCharacter,
EveCorporationInfo,
)
from allianceauth.utils.testing import NoSocketsTestCase
from ....admin import (
MainAllianceFilter,
MainCorporationsFilter,
ServicesUserAdmin,
user_main_organization,
user_profile_pic,
user_username,
user_main_organization,
ServicesUserAdmin,
MainCorporationsFilter,
MainAllianceFilter
)
from ..admin import DiscordUserAdmin
from ..models import DiscordUser
from . import MODULE_PATH
class TestDataMixin(TestCase):
class TestDataMixin(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
@@ -168,7 +174,7 @@ class TestDataMixin(TestCase):
)
class TestColumnRendering(TestDataMixin, TestCase):
class TestColumnRendering(TestDataMixin, NoSocketsTestCase):
def test_user_profile_pic_u1(self):
expected = (
@@ -229,7 +235,7 @@ class TestColumnRendering(TestDataMixin, TestCase):
# actions
class TestFilters(TestDataMixin, TestCase):
class TestFilters(TestDataMixin, NoSocketsTestCase):
def test_filter_main_corporations(self):
@@ -287,3 +293,16 @@ class TestFilters(TestDataMixin, TestCase):
queryset = changelist.get_queryset(request)
expected = [self.user_1.discord]
self.assertSetEqual(set(queryset), set(expected))
@patch(MODULE_PATH + ".admin.DiscordUser.delete_user")
class TestDeleteQueryset(TestDataMixin, NoSocketsTestCase):
def test_should_delete_all_objects(self, mock_delete_user):
# given
request = self.factory.get('/')
request.user = self.user_1
queryset = DiscordUser.objects.filter(user__in=[self.user_2, self.user_3])
# when
self.modeladmin.delete_queryset(request, queryset)
# then
self.assertEqual(mock_delete_user.call_count, 2)

View File

@@ -0,0 +1,16 @@
from unittest.mock import patch
from allianceauth.utils.testing import NoSocketsTestCase
from ..api import discord_guild_id
from . import MODULE_PATH
class TestDiscordGuildId(NoSocketsTestCase):
@patch(MODULE_PATH + ".api.DISCORD_GUILD_ID", "123")
def test_should_return_guild_id_when_configured(self):
self.assertEqual(discord_guild_id(), 123)
@patch(MODULE_PATH + ".api.DISCORD_GUILD_ID", "")
def test_should_return_none_when_not_configured(self):
self.assertIsNone(discord_guild_id())

View File

@@ -1,23 +1,23 @@
from unittest.mock import patch
from django.test import TestCase, RequestFactory
from django.test import RequestFactory
from django.test.utils import override_settings
from allianceauth.notifications.models import Notification
from allianceauth.tests.auth_utils import AuthUtils
from allianceauth.utils.testing import NoSocketsTestCase
from . import TEST_USER_NAME, TEST_USER_ID, add_permissions_to_members, MODULE_PATH
from ..auth_hooks import DiscordService
from ..discord_client import DiscordClient
from ..discord_client.tests.factories import TEST_USER_ID, TEST_USER_NAME
from ..models import DiscordUser
from ..utils import set_logger_to_file
from . import MODULE_PATH, add_permissions_to_members
logger = set_logger_to_file(MODULE_PATH + '.auth_hooks', __file__)
@override_settings(CELERY_ALWAYS_EAGER=True)
class TestDiscordService(TestCase):
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
class TestDiscordService(NoSocketsTestCase):
def setUp(self):
self.member = AuthUtils.create_member(TEST_USER_NAME)
@@ -64,11 +64,11 @@ class TestDiscordService(TestCase):
@patch(MODULE_PATH + '.models.notify')
@patch(MODULE_PATH + '.tasks.DiscordUser')
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
@patch(MODULE_PATH + '.models.create_bot_client')
def test_validate_user(
self, mock_DiscordClient, mock_DiscordUser, mock_notify
self, mock_create_bot_client, mock_DiscordUser, mock_notify
):
mock_DiscordClient.return_value.remove_guild_member.return_value = True
mock_create_bot_client.return_value.remove_guild_member.return_value = True
# Test member is not deleted
service = self.service()
@@ -92,33 +92,38 @@ class TestDiscordService(TestCase):
service.sync_nicknames_bulk([self.member])
self.assertTrue(mock_update_nicknames_bulk.delay.called)
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
def test_delete_user_is_member(self, mock_DiscordClient):
mock_DiscordClient.return_value.remove_guild_member.return_value = True
@patch(MODULE_PATH + '.models.create_bot_client')
def test_delete_user_is_member(self, mock_create_bot_client):
# given
mock_create_bot_client.return_value.remove_guild_member.return_value = True
service = self.service()
# when
service.delete_user(self.member, notify_user=True)
self.assertTrue(mock_DiscordClient.return_value.remove_guild_member.called)
# then
self.assertTrue(mock_create_bot_client.return_value.remove_guild_member.called)
self.assertFalse(DiscordUser.objects.filter(user=self.member).exists())
self.assertTrue(Notification.objects.filter(user=self.member).exists())
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
def test_delete_user_is_not_member(self, mock_DiscordClient):
mock_DiscordClient.return_value.remove_guild_member.return_value = True
@patch(MODULE_PATH + '.models.create_bot_client')
def test_delete_user_is_not_member(self, mock_create_bot_client):
# given
mock_create_bot_client.return_value.remove_guild_member.return_value = True
service = self.service()
# when
service.delete_user(self.none_member)
# then
self.assertFalse(mock_create_bot_client.return_value.remove_guild_member.called)
self.assertFalse(mock_DiscordClient.return_value.remove_guild_member.called)
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
def test_render_services_ctrl_with_username(self, mock_DiscordClient):
@patch(MODULE_PATH + '.auth_hooks.server_name')
def test_render_services_ctrl_with_username(self, mock_server_name):
# given
mock_server_name.return_value = "My server"
service = self.service()
request = self.factory.get('/services/')
request.user = self.member
# when
response = service.render_services_ctrl(request)
# then
self.assertTemplateUsed(service.service_ctrl_template)
self.assertIn('/discord/reset/', response)
self.assertIn('/discord/deactivate/', response)
@@ -130,15 +135,18 @@ class TestDiscordService(TestCase):
response = service.render_services_ctrl(request)
self.assertIn('/discord/activate/', response)
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
def test_render_services_ctrl_wo_username(self, mock_DiscordClient):
@patch(MODULE_PATH + '.auth_hooks.server_name')
def test_render_services_ctrl_wo_username(self, mock_server_name):
# given
mock_server_name.return_value = "My server"
my_member = AuthUtils.create_member('John Doe')
DiscordUser.objects.create(user=my_member, uid=111222333)
service = self.service()
request = self.factory.get('/services/')
request.user = my_member
# when
response = service.render_services_ctrl(request)
# then
self.assertTemplateUsed(service.service_ctrl_template)
self.assertIn('/discord/reset/', response)
self.assertIn('/discord/deactivate/', response)

View File

@@ -0,0 +1,221 @@
from unittest.mock import Mock, patch
from requests.exceptions import HTTPError
from django.contrib.auth.models import Group
from allianceauth.groupmanagement.models import ReservedGroupName
from allianceauth.tests.auth_utils import AuthUtils
from allianceauth.utils.testing import NoSocketsTestCase
from ..core import (
_user_group_names,
calculate_roles_for_user,
group_to_role,
server_name,
user_formatted_nick,
)
from ..discord_client import DiscordApiBackoff, DiscordClient, RolesSet
from ..discord_client.tests.factories import TEST_USER_NAME, create_role
from . import MODULE_PATH, TEST_MAIN_ID, TEST_MAIN_NAME
class TestUserGroupNames(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.group_1 = Group.objects.create(name="Group 1")
cls.group_2 = Group.objects.create(name="Group 2")
def setUp(self):
self.user = AuthUtils.create_member(TEST_USER_NAME)
def test_return_groups_and_state_names_for_user(self):
self.user.groups.add(self.group_1)
result = _user_group_names(self.user)
expected = ["Group 1", "Member"]
self.assertSetEqual(set(result), set(expected))
def test_return_state_only_if_user_has_no_groups(self):
result = _user_group_names(self.user)
expected = ["Member"]
self.assertSetEqual(set(result), set(expected))
class TestUserFormattedNick(NoSocketsTestCase):
def setUp(self):
self.user = AuthUtils.create_user(TEST_USER_NAME)
def test_return_nick_when_user_has_main(self):
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
result = user_formatted_nick(self.user)
expected = TEST_MAIN_NAME
self.assertEqual(result, expected)
def test_return_none_if_user_has_no_main(self):
result = user_formatted_nick(self.user)
self.assertIsNone(result)
@patch(MODULE_PATH + ".core.default_bot_client", spec=True)
class TestRoleForGroup(NoSocketsTestCase):
def test_return_role_if_found(self, mock_bot_client):
# given
role = create_role(name="alpha")
mock_bot_client.match_role_from_name.side_effect = (
lambda guild_id, role_name: role if role.name == role_name else None
)
group = Group.objects.create(name="alpha")
# when/then
self.assertEqual(group_to_role(group), role)
def test_return_empty_dict_if_not_found(self, mock_bot_client):
# given
role = create_role(name="alpha")
mock_bot_client.match_role_from_name.side_effect = (
lambda guild_id, role_name: role if role.name == role_name else None
)
group = Group.objects.create(name="unknown")
# when/then
self.assertIsNone(group_to_role(group))
@patch(MODULE_PATH + ".core.default_bot_client", spec=True)
@patch(MODULE_PATH + ".core.logger", spec=True)
class TestServerName(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.user = AuthUtils.create_user(TEST_USER_NAME)
def test_returns_name_when_api_returns_it(self, mock_logger, mock_bot_client):
# given
my_server_name = "El Dorado"
mock_bot_client.guild_name.return_value = my_server_name
# when
self.assertEqual(server_name(), my_server_name)
# then
self.assertFalse(mock_logger.warning.called)
def test_returns_empty_string_when_api_throws_http_error(
self, mock_logger, mock_bot_client
):
mock_exception = HTTPError("Test exception")
mock_exception.response = Mock(**{"status_code": 440})
mock_bot_client.guild_name.side_effect = mock_exception
self.assertEqual(server_name(), "")
self.assertFalse(mock_logger.warning.called)
def test_returns_empty_string_when_api_throws_service_error(
self, mock_logger, mock_bot_client
):
mock_bot_client.guild_name.side_effect = DiscordApiBackoff(1000)
self.assertEqual(server_name(), "")
self.assertFalse(mock_logger.warning.called)
def test_returns_empty_string_when_api_throws_unexpected_error(
self, mock_logger, mock_bot_client
):
mock_bot_client.guild_name.side_effect = RuntimeError
self.assertEqual(server_name(), "")
self.assertTrue(mock_logger.warning.called)
class TestCalculateRolesForUser(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.user = AuthUtils.create_user(TEST_USER_NAME)
def test_should_return_roles_for_new_member(self):
# given
roles = RolesSet([create_role()])
my_client = Mock(spec=DiscordClient)
my_client.guild_member_roles.return_value = RolesSet([])
my_client.match_or_create_roles_from_names_2.return_value = roles
# when
roles_calculated, changed = calculate_roles_for_user(self.user, my_client, 42)
# then
self.assertTrue(changed)
self.assertEqual(roles_calculated, roles)
def test_should_return_changed_roles_for_existing_member(self):
# given
role_a = create_role()
role_b = create_role()
roles_current = RolesSet([role_a])
roles_matching = RolesSet([role_a, role_b])
my_client = Mock(spec=DiscordClient)
my_client.guild_member_roles.return_value = roles_current
my_client.match_or_create_roles_from_names_2.return_value = roles_matching
# when
roles_calculated, changed = calculate_roles_for_user(self.user, my_client, 42)
# then
self.assertTrue(changed)
self.assertEqual(roles_calculated, roles_matching)
def test_should_indicate_when_roles_are_unchanged(self):
# given
role_a = create_role()
roles_current = RolesSet([role_a])
roles_matching = RolesSet([role_a])
my_client = Mock(spec=DiscordClient)
my_client.guild_member_roles.return_value = roles_current
my_client.match_or_create_roles_from_names_2.return_value = roles_matching
# when
roles_calculated, changed = calculate_roles_for_user(self.user, my_client, 42)
# then
self.assertFalse(changed)
self.assertEqual(roles_calculated, roles_matching)
def test_should_indicate_when_user_is_no_guild_member(self):
# given
role_a = create_role()
roles_matching = RolesSet([role_a])
my_client = Mock(spec=DiscordClient)
my_client.guild_member_roles.return_value = None
my_client.match_or_create_roles_from_names_2.return_value = roles_matching
# when
roles_calculated, changed = calculate_roles_for_user(self.user, my_client, 42)
# then
self.assertIsNone(changed)
self.assertEqual(roles_calculated, roles_matching)
def test_should_preserve_managed_roles_for_existing_member(self):
# given
role_a = create_role()
role_b = create_role()
role_m = create_role(managed=True)
roles_current = RolesSet([role_a, role_m])
roles_matching = RolesSet([role_b])
my_client = Mock(spec=DiscordClient)
my_client.guild_member_roles.return_value = roles_current
my_client.match_or_create_roles_from_names_2.return_value = roles_matching
# when
roles_calculated, changed = calculate_roles_for_user(self.user, my_client, 42)
# then
self.assertTrue(changed)
self.assertEqual(roles_calculated, RolesSet([role_b, role_m]))
def test_should_preserve_reserved_roles_for_existing_member(self):
# given
role_a = create_role()
role_b = create_role()
role_c1 = create_role(name="charlie")
role_c2 = create_role(name="Charlie")
roles_current = RolesSet([role_a, role_c1, role_c2])
roles_matching = RolesSet([role_b])
my_client = Mock(spec=DiscordClient)
my_client.guild_member_roles.return_value = roles_current
my_client.match_or_create_roles_from_names_2.return_value = roles_matching
ReservedGroupName.objects.create(
name="charlie", reason="dummy", created_by="xyz"
)
# when
roles_calculated, changed = calculate_roles_for_user(self.user, my_client, 42)
# then
self.assertTrue(changed)
self.assertEqual(roles_calculated, RolesSet([role_b, role_c1, role_c2]))

View File

@@ -4,54 +4,68 @@ Testing all components of the service, with the exception of the Discord API.
Please note that these tests require Redis and will flush it
"""
from collections import namedtuple
import dataclasses
import json
import logging
from unittest.mock import patch, Mock
from unittest.mock import Mock, patch
from uuid import uuid1
from django_webtest import WebTest
from requests.exceptions import HTTPError
import requests_mock
from requests.exceptions import HTTPError
from django.contrib.auth.models import Group, User
from django.core.cache import caches
from django.shortcuts import reverse
from django.test import TransactionTestCase, TestCase
from django.test.utils import override_settings
from django.test import TransactionTestCase, override_settings
from django.urls import reverse
from django_webtest import WebTest
from allianceauth.authentication.models import State
from allianceauth.eveonline.models import EveCharacter
from allianceauth.groupmanagement.models import ReservedGroupName
from allianceauth.notifications.models import Notification
from allianceauth.tests.auth_utils import AuthUtils
from allianceauth.utils.cache import get_redis_client
from allianceauth.utils.testing import NoSocketsTestCase
from . import (
TEST_GUILD_ID,
TEST_USER_NAME,
TEST_USER_ID,
TEST_USER_DISCRIMINATOR,
TEST_MAIN_NAME,
TEST_MAIN_ID,
MODULE_PATH,
add_permissions_to_members,
ROLE_ALPHA,
ROLE_BRAVO,
ROLE_CHARLIE,
ROLE_MIKE,
create_role,
create_user_info
)
from ..discord_client.app_settings import DISCORD_API_BASE_URL
from ..discord_client.exceptions import DiscordApiBackoff
from ..models import DiscordUser
from .. import tasks
from ..core import create_bot_client
from ..discord_client import DiscordApiBackoff
from ..discord_client.app_settings import DISCORD_API_BASE_URL
from ..discord_client.tests.factories import (
TEST_GUILD_ID,
TEST_USER_ID,
TEST_USER_NAME,
create_discord_error_response_unknown_member,
create_discord_guild_member_object,
create_discord_guild_object,
create_discord_role_object,
create_discord_user_object,
)
from ..models import DiscordUser
from . import MODULE_PATH, TEST_MAIN_ID, TEST_MAIN_NAME, add_permissions_to_members
from .factories import create_discord_user, create_user
logger = logging.getLogger('allianceauth')
ROLE_MEMBER = create_role(99, 'Member')
ROLE_BLUE = create_role(98, 'Blue')
ROLE_ALPHA = create_discord_role_object(id=1, name="alpha")
ROLE_BRAVO = create_discord_role_object(id=2, name="bravo")
ROLE_CHARLIE = create_discord_role_object(id=3, name="charlie")
ROLE_CHARLIE_2 = create_discord_role_object(id=4, name="Charlie") # Discord roles are case sensitive
ROLE_MIKE = create_discord_role_object(id=13, name="mike", managed=True)
ROLE_MEMBER = create_discord_role_object(99, 'Member')
ROLE_BLUE = create_discord_role_object(98, 'Blue')
@dataclasses.dataclass(frozen=True)
class DiscordRequest:
"""Helper for comparing requests made to the Discord API."""
method: str
url: str
text: str = dataclasses.field(compare=False, default=None)
def json(self):
return json.loads(self.text)
# Putting all requests to Discord into objects so we can compare them better
DiscordRequest = namedtuple('DiscordRequest', ['method', 'url'])
user_get_current_request = DiscordRequest(
method='GET',
url=f'{DISCORD_API_BASE_URL}users/@me'
@@ -87,8 +101,7 @@ remove_guild_member_request = DiscordRequest(
def clear_cache():
default_cache = caches['default']
redis = default_cache.get_master_client()
redis = get_redis_client()
redis.flushall()
logger.info('Cache flushed')
@@ -103,13 +116,13 @@ def reset_testdata():
Notification.objects.all().delete()
@patch(MODULE_PATH + '.core.DISCORD_GUILD_ID', TEST_GUILD_ID)
@patch(MODULE_PATH + '.models.DISCORD_GUILD_ID', TEST_GUILD_ID)
@override_settings(CELERY_ALWAYS_EAGER=True)
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=False)
@requests_mock.Mocker()
class TestServiceFeatures(TransactionTestCase):
fixtures = ['disable_analytics.json']
@classmethod
def setUpClass(cls):
super().setUpClass()
@@ -193,7 +206,7 @@ class TestServiceFeatures(TransactionTestCase):
requests_mocker.patch(modify_guild_member_request.url, status_code=204)
# exhausting rate limit
client = DiscordUser.objects._bot_client()
client = create_bot_client()
client._redis.set(
name=client._KEY_GLOBAL_RATE_LIMIT_REMAINING,
value=0,
@@ -209,7 +222,6 @@ class TestServiceFeatures(TransactionTestCase):
requests_made = [
DiscordRequest(r.method, r.url) for r in requests_mocker.request_history
]
self.assertListEqual(requests_made, list())
def test_when_member_is_demoted_to_guest_then_his_account_is_deleted(
@@ -247,7 +259,7 @@ class TestServiceFeatures(TransactionTestCase):
# request mocks
requests_mocker.get(
guild_member_request.url,
json={'user': create_user_info(), 'roles': ['3', '13', '99']}
json=create_discord_guild_member_object(roles=[3, 13, 99])
)
requests_mocker.get(
guild_roles_request.url,
@@ -283,10 +295,7 @@ class TestServiceFeatures(TransactionTestCase):
):
requests_mocker.get(
guild_member_request.url,
json={
'user': create_user_info(),
'roles': ['13', '99']
}
json=create_discord_guild_member_object(roles=[13, 99])
)
requests_mocker.get(
guild_roles_request.url,
@@ -315,10 +324,7 @@ class TestServiceFeatures(TransactionTestCase):
):
requests_mocker.get(
guild_member_request.url,
json={
'user': {'id': str(TEST_USER_ID), 'username': TEST_MAIN_NAME},
'roles': ['13', '99']
}
json=create_discord_guild_member_object(roles=['13', '99'])
)
requests_mocker.get(
guild_roles_request.url,
@@ -344,11 +350,33 @@ class TestServiceFeatures(TransactionTestCase):
self.assertTrue(DiscordUser.objects.user_has_account(self.user))
@override_settings(CELERY_ALWAYS_EAGER=True)
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
@patch(MODULE_PATH + '.managers.DISCORD_GUILD_ID', TEST_GUILD_ID)
@patch(MODULE_PATH + '.models.DISCORD_GUILD_ID', TEST_GUILD_ID)
@requests_mock.Mocker()
class StateTestCase(TestCase):
class TestTasks(NoSocketsTestCase):
def test_should_update_username(self, requests_mocker):
# given
user = create_user()
discord_user = create_discord_user(user)
discord_user_obj = create_discord_user_object()
data = create_discord_guild_member_object(user=discord_user_obj)
requests_mocker.get(guild_member_request.url, json=data)
# when
tasks.update_username.delay(user.pk)
# then
discord_user.refresh_from_db()
self.assertEqual(discord_user.username, discord_user_obj["username"])
self.assertEqual(
discord_user.discriminator, discord_user_obj["discriminator"]
)
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
@patch(MODULE_PATH + '.managers.DISCORD_GUILD_ID', TEST_GUILD_ID)
@patch(MODULE_PATH + '.models.DISCORD_GUILD_ID', TEST_GUILD_ID)
@requests_mock.Mocker()
class StateTestCase(NoSocketsTestCase):
def setUp(self):
clear_cache()
@@ -432,6 +460,7 @@ class StateTestCase(TestCase):
self.user.discord
@patch(MODULE_PATH + '.core.DISCORD_GUILD_ID', TEST_GUILD_ID)
@patch(MODULE_PATH + '.managers.DISCORD_GUILD_ID', TEST_GUILD_ID)
@patch(MODULE_PATH + '.models.DISCORD_GUILD_ID', TEST_GUILD_ID)
@requests_mock.Mocker()
@@ -450,24 +479,25 @@ class TestUserFeatures(WebTest):
)
add_permissions_to_members()
@patch(MODULE_PATH + '.views.messages')
@patch(MODULE_PATH + '.managers.OAuth2Session')
@patch(MODULE_PATH + '.views.messages', spec=True)
@patch(MODULE_PATH + '.managers.OAuth2Session', spec=True)
def test_user_activation_normal(
self, requests_mocker, mock_OAuth2Session, mock_messages
):
# setup
requests_mocker.get(
guild_infos_request.url, json={'id': TEST_GUILD_ID, 'name': 'Test Guild'}
guild_infos_request.url, json=create_discord_guild_object()
)
requests_mocker.get(
user_get_current_request.url,
json=create_user_info(
TEST_USER_ID, TEST_USER_NAME, TEST_USER_DISCRIMINATOR
)
user_get_current_request.url, json=create_discord_user_object()
)
requests_mocker.get(
guild_roles_request.url,
json=[ROLE_ALPHA, ROLE_BRAVO, ROLE_MIKE, ROLE_MEMBER]
guild_roles_request.url, json=[ROLE_ALPHA, ROLE_BRAVO, ROLE_MEMBER]
)
requests_mocker.get(
guild_member_request.url,
status_code=404,
json=create_discord_error_response_unknown_member()
)
requests_mocker.put(add_guild_member_request.url, status_code=201)
@@ -505,33 +535,93 @@ class TestUserFeatures(WebTest):
for r in requests_mocker.request_history:
obj = DiscordRequest(r.method, r.url)
requests_made.append(obj)
self.assertIn(add_guild_member_request, requests_made)
expected = [
guild_infos_request,
user_get_current_request,
guild_roles_request,
add_guild_member_request
]
self.assertListEqual(requests_made, expected)
@patch(MODULE_PATH + '.views.messages', spec=True)
@patch(MODULE_PATH + '.managers.OAuth2Session', spec=True)
def test_should_activate_existing_user_and_keep_managed_and_reserved_roles(
self, requests_mocker, mock_OAuth2Session, mock_messages
):
# setup
requests_mocker.get(
guild_infos_request.url, json=create_discord_guild_object()
)
requests_mocker.get(
user_get_current_request.url, json=create_discord_user_object()
)
requests_mocker.get(
guild_roles_request.url, json=[
ROLE_ALPHA, ROLE_CHARLIE, ROLE_MEMBER, ROLE_MIKE
]
)
requests_mocker.get(
guild_member_request.url,
json=create_discord_guild_member_object(roles=[1, 3, 13])
)
requests_mocker.patch(modify_guild_member_request.url, status_code=204)
ReservedGroupName.objects.create(
name="charlie", reason="dummy", created_by="xyz"
)
@patch(MODULE_PATH + '.views.messages')
@patch(MODULE_PATH + '.managers.OAuth2Session')
authentication_code = 'auth_code'
oauth_url = 'https://www.example.com/oauth'
state = ''
mock_OAuth2Session.return_value.authorization_url.return_value = \
oauth_url, state
# login
self.app.set_user(self.member)
# user opens services page
services_page = self.app.get(reverse('services:services'))
self.assertEqual(services_page.status_code, 200)
# user clicks Discord service activation link on page
response = services_page.click(href=reverse('discord:activate'))
# check we got a redirect to Discord OAuth
self.assertRedirects(
response, expected_url=oauth_url, fetch_redirect_response=False
)
# simulate Discord callback
response = self.app.get(
reverse('discord:callback'), params={'code': authentication_code}
)
# user got a success message
self.assertTrue(mock_messages.success.called)
self.assertFalse(mock_messages.error.called)
my_request = None
for r in requests_mocker.request_history:
obj = DiscordRequest(r.method, r.url, r.text)
if obj == modify_guild_member_request:
my_request = obj
break
else:
self.fail("Request not found")
self.assertSetEqual(set(my_request.json()["roles"]), {3, 13, 99})
@patch(MODULE_PATH + '.views.messages', spec=True)
@patch(MODULE_PATH + '.managers.OAuth2Session', spec=True)
def test_user_activation_failed(
self, requests_mocker, mock_OAuth2Session, mock_messages
):
# setup
requests_mocker.get(
guild_infos_request.url, json={'id': TEST_GUILD_ID, 'name': 'Test Guild'}
guild_infos_request.url, json=create_discord_guild_object()
)
requests_mocker.get(
user_get_current_request.url,
json=create_user_info(
TEST_USER_ID, TEST_USER_NAME, TEST_USER_DISCRIMINATOR
)
user_get_current_request.url, json=create_discord_user_object()
)
requests_mocker.get(
guild_roles_request.url,
json=[ROLE_ALPHA, ROLE_BRAVO, ROLE_MIKE, ROLE_MEMBER]
guild_roles_request.url, json=[ROLE_ALPHA, ROLE_BRAVO, ROLE_MEMBER]
)
requests_mocker.get(
guild_member_request.url,
status_code=404,
json=create_discord_error_response_unknown_member()
)
mock_exception = HTTPError('error')
@@ -573,20 +663,13 @@ class TestUserFeatures(WebTest):
for r in requests_mocker.request_history:
obj = DiscordRequest(r.method, r.url)
requests_made.append(obj)
self.assertIn(add_guild_member_request, requests_made)
expected = [
guild_infos_request,
user_get_current_request,
guild_roles_request,
add_guild_member_request
]
self.assertListEqual(requests_made, expected)
@patch(MODULE_PATH + '.views.messages')
@patch(MODULE_PATH + '.views.messages', spec=True)
def test_user_deactivation_normal(self, requests_mocker, mock_messages):
# setup
requests_mocker.get(
guild_infos_request.url, json={'id': TEST_GUILD_ID, 'name': 'Test Guild'}
guild_infos_request.url, json=create_discord_guild_object()
)
requests_mocker.delete(remove_guild_member_request.url, status_code=204)
DiscordUser.objects.create(user=self.member, uid=TEST_USER_ID)
@@ -612,15 +695,13 @@ class TestUserFeatures(WebTest):
for r in requests_mocker.request_history:
obj = DiscordRequest(r.method, r.url)
requests_made.append(obj)
self.assertIn(remove_guild_member_request, requests_made)
expected = [guild_infos_request, remove_guild_member_request]
self.assertListEqual(requests_made, expected)
@patch(MODULE_PATH + '.views.messages')
@patch(MODULE_PATH + '.views.messages', spec=True)
def test_user_deactivation_fails(self, requests_mocker, mock_messages):
# setup
requests_mocker.get(
guild_infos_request.url, json={'id': TEST_GUILD_ID, 'name': 'Test Guild'}
guild_infos_request.url, json=create_discord_guild_object()
)
mock_exception = HTTPError('error')
mock_exception.response = Mock()
@@ -650,11 +731,9 @@ class TestUserFeatures(WebTest):
for r in requests_mocker.request_history:
obj = DiscordRequest(r.method, r.url)
requests_made.append(obj)
self.assertIn(remove_guild_member_request, requests_made)
expected = [guild_infos_request, remove_guild_member_request]
self.assertListEqual(requests_made, expected)
@patch(MODULE_PATH + '.views.messages')
@patch(MODULE_PATH + '.views.messages', spec=True)
def test_user_add_new_server(self, requests_mocker, mock_messages):
# setup
mock_exception = HTTPError(Mock(**{"response.status_code": 400}))
@@ -686,14 +765,13 @@ class TestUserFeatures(WebTest):
services_page = self.app.get(reverse('services:services'))
self.assertEqual(services_page.status_code, 200)
@override_settings(CELERY_ALWAYS_EAGER=True)
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
@patch(MODULE_PATH + ".core.default_bot_client", spec=True)
def test_server_name_is_updated_by_task(
self, requests_mocker
self, requests_mocker, mock_bot_client
):
# setup
requests_mocker.get(
guild_infos_request.url, json={'id': TEST_GUILD_ID, 'name': 'Test Guild'}
)
mock_bot_client.guild_name.return_value = "Test Guild"
# run task to update usernames
tasks.update_all_usernames()

View File

@@ -1,364 +1,395 @@
from unittest.mock import patch, Mock
import urllib
from unittest.mock import Mock, patch
from requests.exceptions import HTTPError
from django.contrib.auth.models import Group, User
from django.test import TestCase
from django.contrib.auth.models import User
from allianceauth.tests.auth_utils import AuthUtils
from allianceauth.utils.testing import NoSocketsTestCase
from . import (
from ..app_settings import DISCORD_APP_ID, DISCORD_APP_SECRET, DISCORD_CALLBACK_URL
from ..discord_client import (
DISCORD_OAUTH_BASE_URL,
DISCORD_OAUTH_TOKEN_URL,
DiscordApiBackoff,
DiscordClient,
RolesSet,
)
from ..discord_client.tests.factories import (
TEST_GUILD_ID,
TEST_USER_NAME,
TEST_USER_ID,
TEST_MAIN_NAME,
TEST_MAIN_ID,
MODULE_PATH,
ROLE_ALPHA,
ROLE_BRAVO,
ROLE_CHARLIE,
TEST_USER_NAME,
create_role,
create_user,
)
from ..discord_client.tests import create_matched_role
from ..app_settings import (
DISCORD_APP_ID,
DISCORD_APP_SECRET,
DISCORD_CALLBACK_URL,
)
from ..discord_client import DiscordClient, DiscordApiBackoff
from ..models import DiscordUser
from ..utils import set_logger_to_file
from . import MODULE_PATH, TEST_MAIN_NAME
logger = set_logger_to_file(MODULE_PATH + '.managers', __file__)
@patch(MODULE_PATH + '.managers.DISCORD_GUILD_ID', TEST_GUILD_ID)
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
@patch(MODULE_PATH + '.models.DiscordUser.objects._exchange_auth_code_for_token')
@patch(MODULE_PATH + '.models.DiscordUser.objects.user_group_names')
@patch(MODULE_PATH + '.models.DiscordUser.objects.user_formatted_nick')
class TestAddUser(TestCase):
@patch(MODULE_PATH + '.managers.create_bot_client', spec=True)
@patch(
MODULE_PATH + '.models.DiscordUser.objects._exchange_auth_code_for_token', spec=True
)
@patch(MODULE_PATH + '.managers.calculate_roles_for_user', spec=True)
@patch(MODULE_PATH + '.managers.user_formatted_nick', spec=True)
class TestAddUser(NoSocketsTestCase):
def setUp(self):
self.user = AuthUtils.create_user(TEST_USER_NAME)
self.user_info = {
'id': TEST_USER_ID,
'name': TEST_USER_NAME,
'username': TEST_USER_NAME,
'discriminator': '1234',
}
self.access_token = 'accesstoken'
def test_can_create_user_no_roles_no_nick(
self,
mock_user_formatted_nick,
mock_user_group_names,
mock_calculate_roles_for_user,
mock_exchange_auth_code_for_token,
mock_DiscordClient
mock_create_bot_client,
mock_DiscordClient,
):
# given
discord_user = create_user(id=TEST_USER_ID)
mock_user_formatted_nick.return_value = None
mock_user_group_names.return_value = []
mock_exchange_auth_code_for_token.return_value = self.access_token
mock_DiscordClient.return_value.current_user.return_value = self.user_info
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = []
mock_DiscordClient.return_value.add_guild_member.return_value = True
mock_DiscordClient.return_value.current_user.return_value = discord_user
mock_calculate_roles_for_user.return_value = RolesSet([]), None
mock_create_bot_client.return_value.add_guild_member.return_value = True
# when
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
# then
self.assertTrue(result)
self.assertTrue(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
args, kwargs = mock_DiscordClient.return_value.add_guild_member.call_args
_, kwargs = mock_create_bot_client.return_value.add_guild_member.call_args
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
self.assertEqual(kwargs['access_token'], self.access_token)
self.assertIsNone(kwargs['role_ids'])
self.assertFalse(kwargs['role_ids'])
self.assertIsNone(kwargs['nick'])
def test_can_create_user_with_roles_no_nick(
self,
mock_user_formatted_nick,
mock_user_group_names,
mock_calculate_roles_for_user,
mock_exchange_auth_code_for_token,
mock_DiscordClient
mock_create_bot_client,
mock_DiscordClient,
):
roles = [
create_matched_role(ROLE_ALPHA),
create_matched_role(ROLE_BRAVO),
create_matched_role(ROLE_CHARLIE)
]
# given
role_a = create_role(id=1)
role_b = create_role(id=2)
roles_calculated = RolesSet([role_a, role_b])
discord_user = create_user(id=TEST_USER_ID)
mock_user_formatted_nick.return_value = None
mock_user_group_names.return_value = ['a', 'b', 'c']
mock_exchange_auth_code_for_token.return_value = self.access_token
mock_DiscordClient.return_value.current_user.return_value = self.user_info
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = roles
mock_DiscordClient.return_value.add_guild_member.return_value = True
mock_DiscordClient.return_value.current_user.return_value = discord_user
mock_calculate_roles_for_user.return_value = roles_calculated, None
mock_create_bot_client.return_value.add_guild_member.return_value = True
# when
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
# then
self.assertTrue(result)
self.assertTrue(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
args, kwargs = mock_DiscordClient.return_value.add_guild_member.call_args
_, kwargs = mock_create_bot_client.return_value.add_guild_member.call_args
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
self.assertEqual(kwargs['access_token'], self.access_token)
self.assertSetEqual(set(kwargs['role_ids']), {1, 2, 3})
self.assertSetEqual(set(kwargs['role_ids']), {1, 2})
self.assertIsNone(kwargs['nick'])
def test_can_activate_existing_user_with_roles_no_nick(
self,
mock_user_formatted_nick,
mock_user_group_names,
mock_calculate_roles_for_user,
mock_exchange_auth_code_for_token,
mock_DiscordClient
mock_create_bot_client,
mock_DiscordClient,
):
roles = [
create_matched_role(ROLE_ALPHA),
create_matched_role(ROLE_BRAVO),
create_matched_role(ROLE_CHARLIE)
]
# given
role_a = create_role(id=1)
role_b = create_role(id=2)
roles_calculated = RolesSet([role_a, role_b])
discord_user = create_user(id=TEST_USER_ID)
mock_user_formatted_nick.return_value = None
mock_user_group_names.return_value = ['a', 'b', 'c']
mock_exchange_auth_code_for_token.return_value = self.access_token
mock_DiscordClient.return_value.current_user.return_value = self.user_info
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = roles
mock_DiscordClient.return_value.add_guild_member.return_value = None
mock_DiscordClient.return_value.modify_guild_member.return_value = True
mock_DiscordClient.return_value.current_user.return_value = discord_user
mock_calculate_roles_for_user.return_value = roles_calculated, False
mock_create_bot_client.return_value.add_guild_member.return_value = None
mock_create_bot_client.return_value.modify_guild_member.return_value = True
# when
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
# then
self.assertTrue(result)
self.assertTrue(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
args, kwargs = mock_DiscordClient.return_value.modify_guild_member.call_args
_, kwargs = mock_create_bot_client.return_value.modify_guild_member.call_args
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
self.assertSetEqual(set(kwargs['role_ids']), {1, 2, 3})
self.assertSetEqual(set(kwargs['role_ids']), {1, 2})
self.assertIsNone(kwargs['nick'])
@patch(MODULE_PATH + '.managers.DISCORD_SYNC_NAMES', True)
def test_can_create_user_no_roles_with_nick(
self,
mock_user_formatted_nick,
mock_user_group_names,
mock_calculate_roles_for_user,
mock_exchange_auth_code_for_token,
mock_DiscordClient
mock_create_bot_client,
mock_DiscordClient,
):
# given
discord_user = create_user(id=TEST_USER_ID)
mock_user_formatted_nick.return_value = TEST_MAIN_NAME
mock_user_group_names.return_value = []
mock_exchange_auth_code_for_token.return_value = self.access_token
mock_DiscordClient.return_value.current_user.return_value = self.user_info
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = []
mock_DiscordClient.return_value.add_guild_member.return_value = True
mock_DiscordClient.return_value.current_user.return_value = discord_user
mock_calculate_roles_for_user.return_value = RolesSet([]), None
mock_create_bot_client.return_value.add_guild_member.return_value = True
# when
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
# then
self.assertTrue(result)
self.assertTrue(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
args, kwargs = mock_DiscordClient.return_value.add_guild_member.call_args
_, kwargs = mock_create_bot_client.return_value.add_guild_member.call_args
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
self.assertEqual(kwargs['access_token'], self.access_token)
self.assertIsNone(kwargs['role_ids'])
self.assertFalse(kwargs['role_ids'])
self.assertEqual(kwargs['nick'], TEST_MAIN_NAME)
@patch(MODULE_PATH + '.managers.DISCORD_SYNC_NAMES', True)
def test_can_activate_existing_user_no_roles_with_nick(
self,
mock_user_formatted_nick,
mock_user_group_names,
mock_calculate_roles_for_user,
mock_exchange_auth_code_for_token,
mock_DiscordClient
mock_create_bot_client,
mock_DiscordClient,
):
# given
discord_user = create_user(id=TEST_USER_ID)
mock_user_formatted_nick.return_value = TEST_MAIN_NAME
mock_user_group_names.return_value = []
mock_exchange_auth_code_for_token.return_value = self.access_token
mock_DiscordClient.return_value.current_user.return_value = self.user_info
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = []
mock_DiscordClient.return_value.add_guild_member.return_value = None
mock_DiscordClient.return_value.modify_guild_member.return_value = True
mock_DiscordClient.return_value.current_user.return_value = discord_user
mock_calculate_roles_for_user.return_value = RolesSet([]), False
mock_create_bot_client.return_value.add_guild_member.return_value = None
mock_create_bot_client.return_value.modify_guild_member.return_value = True
# when
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
# then
self.assertTrue(result)
self.assertTrue(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
args, kwargs = mock_DiscordClient.return_value.modify_guild_member.call_args
_, kwargs = mock_create_bot_client.return_value.modify_guild_member.call_args
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
self.assertIsNone(kwargs['role_ids'])
self.assertFalse(kwargs['role_ids'])
self.assertEqual(kwargs['nick'], TEST_MAIN_NAME)
@patch(MODULE_PATH + '.managers.DISCORD_SYNC_NAMES', False)
def test_can_create_user_no_roles_and_without_nick_if_turned_off(
self,
mock_user_formatted_nick,
mock_user_group_names,
mock_calculate_roles_for_user,
mock_exchange_auth_code_for_token,
mock_DiscordClient
mock_create_bot_client,
mock_DiscordClient,
):
# given
discord_user = create_user(id=TEST_USER_ID)
mock_user_formatted_nick.return_value = TEST_MAIN_NAME
mock_user_group_names.return_value = []
mock_exchange_auth_code_for_token.return_value = self.access_token
mock_DiscordClient.return_value.current_user.return_value = self.user_info
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = []
mock_DiscordClient.return_value.add_guild_member.return_value = True
mock_DiscordClient.return_value.current_user.return_value = discord_user
mock_calculate_roles_for_user.return_value = RolesSet([]), None
mock_create_bot_client.return_value.add_guild_member.return_value = True
# when
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
# then
self.assertTrue(result)
self.assertTrue(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
args, kwargs = mock_DiscordClient.return_value.add_guild_member.call_args
_, kwargs = mock_create_bot_client.return_value.add_guild_member.call_args
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
self.assertEqual(kwargs['access_token'], self.access_token)
self.assertIsNone(kwargs['role_ids'])
self.assertFalse(kwargs['role_ids'])
self.assertIsNone(kwargs['nick'])
def test_can_activate_existing_guild_member(
self,
mock_user_formatted_nick,
mock_user_group_names,
mock_calculate_roles_for_user,
mock_exchange_auth_code_for_token,
mock_DiscordClient
mock_create_bot_client,
mock_DiscordClient,
):
# given
discord_user = create_user(id=TEST_USER_ID)
roles_calculated = RolesSet([create_role()])
mock_user_formatted_nick.return_value = None
mock_user_group_names.return_value = []
mock_exchange_auth_code_for_token.return_value = self.access_token
mock_DiscordClient.return_value.current_user.return_value = self.user_info
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = []
mock_DiscordClient.return_value.add_guild_member.return_value = None
mock_DiscordClient.return_value.modify_guild_member.return_value = True
mock_DiscordClient.return_value.current_user.return_value = discord_user
mock_calculate_roles_for_user.return_value = roles_calculated, False
mock_create_bot_client.return_value.add_guild_member.return_value = None
mock_create_bot_client.return_value.modify_guild_member.return_value = True
# when
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
# then
self.assertTrue(result)
self.assertTrue(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
self.assertTrue(mock_create_bot_client.return_value.modify_guild_member.called)
def test_can_activate_existing_member_with_roles(
self,
mock_user_formatted_nick,
mock_calculate_roles_for_user,
mock_exchange_auth_code_for_token,
mock_create_bot_client,
mock_DiscordClient,
):
# given
discord_user = create_user(id=TEST_USER_ID)
roles_calculated = RolesSet([create_role(id=1)])
mock_user_formatted_nick.return_value = None
mock_exchange_auth_code_for_token.return_value = self.access_token
mock_DiscordClient.return_value.current_user.return_value = discord_user
mock_calculate_roles_for_user.return_value = roles_calculated, False
mock_create_bot_client.return_value.add_guild_member.return_value = None
mock_create_bot_client.return_value.modify_guild_member.return_value = True
# when
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
# then
self.assertTrue(result)
self.assertTrue(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
_, kwargs = mock_create_bot_client.return_value.modify_guild_member.call_args
self.assertSetEqual(set(kwargs['role_ids']), {1})
def test_can_activate_existing_guild_member_failure(
self,
mock_user_formatted_nick,
mock_user_group_names,
mock_calculate_roles_for_user,
mock_exchange_auth_code_for_token,
mock_DiscordClient
mock_create_bot_client,
mock_DiscordClient,
):
# given
discord_user = create_user(id=TEST_USER_ID)
roles_calculated = RolesSet([create_role()])
mock_user_formatted_nick.return_value = None
mock_user_group_names.return_value = []
mock_exchange_auth_code_for_token.return_value = self.access_token
mock_DiscordClient.return_value.current_user.return_value = self.user_info
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = []
mock_DiscordClient.return_value.add_guild_member.return_value = None
mock_DiscordClient.return_value.modify_guild_member.return_value = False
mock_DiscordClient.return_value.current_user.return_value = discord_user
mock_calculate_roles_for_user.return_value = roles_calculated, False
mock_create_bot_client.return_value.add_guild_member.return_value = None
mock_create_bot_client.return_value.modify_guild_member.return_value = False
# when
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
# then
self.assertFalse(result)
self.assertFalse(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
self.assertTrue(mock_create_bot_client.return_value.modify_guild_member.called)
def test_return_false_when_user_creation_fails(
self,
mock_user_formatted_nick,
mock_user_group_names,
mock_calculate_roles_for_user,
mock_exchange_auth_code_for_token,
mock_DiscordClient
mock_create_bot_client,
mock_DiscordClient,
):
# given
discord_user = create_user(id=TEST_USER_ID)
mock_user_formatted_nick.return_value = None
mock_user_group_names.return_value = []
mock_exchange_auth_code_for_token.return_value = self.access_token
mock_DiscordClient.return_value.current_user.return_value = self.user_info
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = []
mock_DiscordClient.return_value.add_guild_member.return_value = False
mock_DiscordClient.return_value.current_user.return_value = discord_user
mock_calculate_roles_for_user.return_value = RolesSet([]), None
mock_create_bot_client.return_value.add_guild_member.return_value = False
# when
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
# then
self.assertFalse(result)
self.assertFalse(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
self.assertTrue(mock_create_bot_client.return_value.add_guild_member.called)
def test_return_false_when_on_api_backoff(
self,
mock_user_formatted_nick,
mock_user_group_names,
mock_calculate_roles_for_user,
mock_exchange_auth_code_for_token,
mock_DiscordClient
mock_create_bot_client,
mock_DiscordClient,
):
# given
discord_user = create_user(id=TEST_USER_ID)
mock_user_formatted_nick.return_value = None
mock_user_group_names.return_value = []
mock_exchange_auth_code_for_token.return_value = self.access_token
mock_DiscordClient.return_value.current_user.return_value = self.user_info
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = []
mock_DiscordClient.return_value.add_guild_member.side_effect = \
mock_DiscordClient.return_value.current_user.return_value = discord_user
mock_calculate_roles_for_user.return_value = RolesSet([]), None
mock_create_bot_client.return_value.add_guild_member.side_effect = \
DiscordApiBackoff(999)
# when
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
# then
self.assertFalse(result)
self.assertFalse(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
self.assertTrue(mock_create_bot_client.return_value.add_guild_member.called)
def test_return_false_on_http_error(
self,
mock_user_formatted_nick,
mock_user_group_names,
mock_calculate_roles_for_user,
mock_exchange_auth_code_for_token,
mock_DiscordClient
mock_create_bot_client,
mock_DiscordClient,
):
# given
discord_user = create_user(id=TEST_USER_ID)
mock_user_formatted_nick.return_value = None
mock_user_group_names.return_value = []
mock_exchange_auth_code_for_token.return_value = self.access_token
mock_DiscordClient.return_value.current_user.return_value = self.user_info
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = []
mock_DiscordClient.return_value.current_user.return_value = discord_user
mock_calculate_roles_for_user.return_value = RolesSet([]), None
mock_exception = HTTPError('error')
mock_exception.response = Mock()
mock_exception.response.status_code = 500
mock_DiscordClient.return_value.add_guild_member.side_effect = mock_exception
mock_create_bot_client.return_value.add_guild_member.side_effect = mock_exception
# when
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
# then
self.assertFalse(result)
self.assertFalse(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
self.assertTrue(mock_create_bot_client.return_value.add_guild_member.called)
class TestOauthHelpers(TestCase):
class TestOauthHelpers(NoSocketsTestCase):
@patch(MODULE_PATH + '.managers.DISCORD_APP_ID', '123456')
def test_generate_bot_add_url(self):
bot_add_url = DiscordUser.objects.generate_bot_add_url()
auth_url = DiscordClient.OAUTH_BASE_URL
auth_url = DISCORD_OAUTH_BASE_URL
real_bot_add_url = (
f'{auth_url}?client_id=123456&scope=bot'
f'&permissions={DiscordUser.objects.BOT_PERMISSIONS}'
@@ -368,12 +399,12 @@ class TestOauthHelpers(TestCase):
def test_generate_oauth_redirect_url(self):
oauth_url = DiscordUser.objects.generate_oauth_redirect_url()
self.assertIn(DiscordClient.OAUTH_BASE_URL, oauth_url)
self.assertIn(DISCORD_OAUTH_BASE_URL, oauth_url)
self.assertIn('+'.join(DiscordUser.objects.SCOPES), oauth_url)
self.assertIn(DISCORD_APP_ID, oauth_url)
self.assertIn(urllib.parse.quote_plus(DISCORD_CALLBACK_URL), oauth_url)
@patch(MODULE_PATH + '.managers.OAuth2Session')
@patch(MODULE_PATH + '.managers.OAuth2Session', spec=True)
def test_process_callback_code(self, oauth):
instance = oauth.return_value
instance.fetch_token.return_value = {'access_token': 'mywonderfultoken'}
@@ -386,52 +417,13 @@ class TestOauthHelpers(TestCase):
self.assertEqual(kwargs['redirect_uri'], DISCORD_CALLBACK_URL)
self.assertTrue(instance.fetch_token.called)
args, kwargs = instance.fetch_token.call_args
self.assertEqual(args[0], DiscordClient.OAUTH_TOKEN_URL)
self.assertEqual(args[0], DISCORD_OAUTH_TOKEN_URL)
self.assertEqual(kwargs['client_secret'], DISCORD_APP_SECRET)
self.assertEqual(kwargs['code'], '12345')
self.assertEqual(token, 'mywonderfultoken')
class TestUserFormattedNick(TestCase):
def setUp(self):
self.user = AuthUtils.create_user(TEST_USER_NAME)
def test_return_nick_when_user_has_main(self):
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
result = DiscordUser.objects.user_formatted_nick(self.user)
expected = TEST_MAIN_NAME
self.assertEqual(result, expected)
def test_return_none_if_user_has_no_main(self):
result = DiscordUser.objects.user_formatted_nick(self.user)
self.assertIsNone(result)
class TestUserGroupNames(TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.group_1 = Group.objects.create(name='Group 1')
cls.group_2 = Group.objects.create(name='Group 2')
def setUp(self):
self.user = AuthUtils.create_member(TEST_USER_NAME)
def test_return_groups_and_state_names_for_user(self):
self.user.groups.add(self.group_1)
result = DiscordUser.objects.user_group_names(self.user)
expected = ['Group 1', 'Member']
self.assertSetEqual(set(result), set(expected))
def test_return_state_only_if_user_has_no_groups(self):
result = DiscordUser.objects.user_group_names(self.user)
expected = ['Member']
self.assertSetEqual(set(result), set(expected))
class TestUserHasAccount(TestCase):
class TestUserHasAccount(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
@@ -453,59 +445,22 @@ class TestUserHasAccount(TestCase):
self.assertFalse(DiscordUser.objects.user_has_account('abc'))
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
@patch(MODULE_PATH + '.managers.logger')
class TestServerName(TestCase):
class TestOtherMethods(NoSocketsTestCase):
@patch(MODULE_PATH + '.managers.core_group_to_role', spec=True)
def test_should_call_group_to_role(self, mock_core_group_to_role):
# given
role = create_role(id=1, name="alpha", managed=False)
mock_core_group_to_role.return_value = role
# when
result = DiscordUser.objects.group_to_role(Mock())
# then
self.assertEqual(result["id"], 1)
self.assertEqual(result["name"], "alpha")
self.assertEqual(result["managed"], False)
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.user = AuthUtils.create_user(TEST_USER_NAME)
def test_returns_name_when_api_returns_it(self, mock_logger, mock_DiscordClient):
server_name = "El Dorado"
mock_DiscordClient.return_value.guild_name.return_value = server_name
self.assertEqual(DiscordUser.objects.server_name(), server_name)
self.assertFalse(mock_logger.warning.called)
def test_returns_empty_string_when_api_throws_http_error(
self, mock_logger, mock_DiscordClient
):
mock_exception = HTTPError('Test exception')
mock_exception.response = Mock(**{"status_code": 440})
mock_DiscordClient.return_value.guild_name.side_effect = mock_exception
self.assertEqual(DiscordUser.objects.server_name(), "")
self.assertFalse(mock_logger.warning.called)
def test_returns_empty_string_when_api_throws_service_error(
self, mock_logger, mock_DiscordClient
):
mock_DiscordClient.return_value.guild_name.side_effect = DiscordApiBackoff(1000)
self.assertEqual(DiscordUser.objects.server_name(), "")
self.assertFalse(mock_logger.warning.called)
def test_returns_empty_string_when_api_throws_unexpected_error(
self, mock_logger, mock_DiscordClient
):
mock_DiscordClient.return_value.guild_name.side_effect = RuntimeError
self.assertEqual(DiscordUser.objects.server_name(), "")
self.assertTrue(mock_logger.warning.called)
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
class TestRoleForGroup(TestCase):
def test_return_role_if_found(self, mock_DiscordClient):
mock_DiscordClient.return_value.match_role_from_name.return_value = ROLE_ALPHA
group = Group.objects.create(name='alpha')
self.assertEqual(DiscordUser.objects.group_to_role(group), ROLE_ALPHA)
def test_return_empty_dict_if_not_found(self, mock_DiscordClient):
mock_DiscordClient.return_value.match_role_from_name.return_value = dict()
group = Group.objects.create(name='unknown')
self.assertEqual(DiscordUser.objects.group_to_role(group), dict())
@patch(MODULE_PATH + '.managers.core_server_name', spec=True)
def test_should_call_server_name(self, mock_core_server_name):
# when
DiscordUser.objects.server_name()
# then
self.assertTrue(mock_core_server_name.called)

View File

@@ -1,34 +1,27 @@
from unittest.mock import patch, Mock
from unittest.mock import Mock, patch
from requests.exceptions import HTTPError
from django.test import TestCase
from allianceauth.tests.auth_utils import AuthUtils
from allianceauth.groupmanagement.models import ReservedGroupName
from allianceauth.utils.testing import NoSocketsTestCase
from . import (
TEST_USER_NAME,
from ..discord_client import DiscordApiBackoff, RolesSet
from ..discord_client.tests.factories import (
TEST_USER_ID,
TEST_MAIN_NAME,
TEST_MAIN_ID,
MODULE_PATH,
ROLE_ALPHA,
ROLE_BRAVO,
ROLE_CHARLIE,
ROLE_CHARLIE_2,
ROLE_MIKE,
TEST_USER_NAME,
create_guild_member,
create_role,
)
from ..discord_client import DiscordClient, DiscordApiBackoff
from ..discord_client.tests import create_matched_role
from ..discord_client.tests.factories import create_user as create_guild_user
from ..models import DiscordUser
from ..utils import set_logger_to_file
from . import MODULE_PATH, TEST_MAIN_ID, TEST_MAIN_NAME
from .factories import create_discord_user, create_user
logger = set_logger_to_file(MODULE_PATH + '.models', __file__)
class TestBasicsAndHelpers(TestCase):
class TestBasicsAndHelpers(NoSocketsTestCase):
def test_str(self):
user = AuthUtils.create_user(TEST_USER_NAME)
@@ -43,8 +36,8 @@ class TestBasicsAndHelpers(TestCase):
self.assertEqual(repr(discord_user), expected)
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
class TestUpdateNick(TestCase):
@patch(MODULE_PATH + '.models.default_bot_client', spec=True)
class TestUpdateNick(NoSocketsTestCase):
def setUp(self):
self.user = AuthUtils.create_user(TEST_USER_NAME)
@@ -52,119 +45,92 @@ class TestUpdateNick(TestCase):
user=self.user, uid=TEST_USER_ID
)
def test_can_update(self, mock_DiscordClient):
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
mock_DiscordClient.return_value.modify_guild_member.return_value = True
def test_can_update(self, mock_default_bot_client):
# given
AuthUtils.add_main_character_2(
self.user, TEST_MAIN_NAME, TEST_MAIN_ID, disconnect_signals=True
)
mock_default_bot_client.modify_guild_member.return_value = True
# when
result = self.discord_user.update_nickname()
# then
self.assertTrue(result)
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
def test_dont_update_if_user_has_no_main(self, mock_DiscordClient):
mock_DiscordClient.return_value.modify_guild_member.return_value = False
self.assertTrue(mock_default_bot_client.modify_guild_member.called)
def test_dont_update_if_user_has_no_main(self, mock_default_bot_client):
# given
mock_default_bot_client.modify_guild_member.return_value = False
# when
result = self.discord_user.update_nickname()
# then
self.assertFalse(result)
self.assertFalse(mock_DiscordClient.return_value.modify_guild_member.called)
def test_return_none_if_user_no_longer_a_member(self, mock_DiscordClient):
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
mock_DiscordClient.return_value.modify_guild_member.return_value = None
self.assertFalse(mock_default_bot_client.modify_guild_member.called)
def test_return_none_if_user_no_longer_a_member(self, mock_default_bot_client):
# given
AuthUtils.add_main_character_2(
self.user, TEST_MAIN_NAME, TEST_MAIN_ID, disconnect_signals=True
)
mock_default_bot_client.modify_guild_member.return_value = None
# when
result = self.discord_user.update_nickname()
# then
self.assertIsNone(result)
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
def test_return_false_if_api_returns_false(self, mock_DiscordClient):
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
mock_DiscordClient.return_value.modify_guild_member.return_value = False
self.assertTrue(mock_default_bot_client.modify_guild_member.called)
def test_return_false_if_api_returns_false(self, mock_default_bot_client):
# given
AuthUtils.add_main_character_2(
self.user, TEST_MAIN_NAME, TEST_MAIN_ID, disconnect_signals=True
)
mock_default_bot_client.modify_guild_member.return_value = False
# when
result = self.discord_user.update_nickname()
# then
self.assertFalse(result)
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
self.assertTrue(mock_default_bot_client.modify_guild_member.called)
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
class TestUpdateUsername(TestCase):
@patch(MODULE_PATH + '.models.default_bot_client.guild_member', spec=True)
class TestUpdateUsername(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.user = AuthUtils.create_user(TEST_USER_NAME)
cls.user = create_user()
def setUp(self):
self.discord_user = DiscordUser.objects.create(
user=self.user,
uid=TEST_USER_ID,
username=TEST_MAIN_NAME,
discriminator='1234'
)
def test_can_update(self, mock_DiscordClient):
def test_can_update(self, mock_guild_member):
# given
discord_user = create_discord_user(user=self.user)
new_username = 'New name'
new_discriminator = '9876'
user_info = {
'user': {
'id': str(TEST_USER_ID),
'username': new_username,
'discriminator': new_discriminator,
}
}
mock_DiscordClient.return_value.guild_member.return_value = user_info
result = self.discord_user.update_username()
guild_user = create_guild_user(
username='New name', discriminator=new_discriminator
)
mock_guild_member.return_value = create_guild_member(user=guild_user)
# when
result = discord_user.update_username()
# then
self.assertTrue(result)
self.assertTrue(mock_DiscordClient.return_value.guild_member.called)
self.discord_user.refresh_from_db()
self.assertEqual(self.discord_user.username, new_username)
self.assertEqual(self.discord_user.discriminator, new_discriminator)
self.assertTrue(mock_guild_member.called)
discord_user.refresh_from_db()
self.assertEqual(discord_user.username, new_username)
self.assertEqual(discord_user.discriminator, new_discriminator)
def test_return_none_if_user_no_longer_a_member(self, mock_DiscordClient):
mock_DiscordClient.return_value.guild_member.return_value = None
result = self.discord_user.update_username()
def test_return_none_if_user_no_longer_a_member(self, mock_guild_member):
# given
discord_user = create_discord_user(user=self.user)
mock_guild_member.return_value = None
# when
result = discord_user.update_username()
# then
self.assertIsNone(result)
self.assertTrue(mock_DiscordClient.return_value.guild_member.called)
def test_return_false_if_api_returns_false(self, mock_DiscordClient):
mock_DiscordClient.return_value.guild_member.return_value = False
result = self.discord_user.update_username()
self.assertFalse(result)
self.assertTrue(mock_DiscordClient.return_value.guild_member.called)
def test_return_false_if_api_returns_corrput_data_1(self, mock_DiscordClient):
mock_DiscordClient.return_value.guild_member.return_value = {'invalid': True}
result = self.discord_user.update_username()
self.assertFalse(result)
self.assertTrue(mock_DiscordClient.return_value.guild_member.called)
def test_return_false_if_api_returns_corrput_data_2(self, mock_DiscordClient):
user_info = {
'user': {
'id': str(TEST_USER_ID),
'discriminator': '1234',
}
}
mock_DiscordClient.return_value.guild_member.return_value = user_info
result = self.discord_user.update_username()
self.assertFalse(result)
self.assertTrue(mock_DiscordClient.return_value.guild_member.called)
def test_return_false_if_api_returns_corrput_data_3(self, mock_DiscordClient):
user_info = {
'user': {
'id': str(TEST_USER_ID),
'username': TEST_USER_NAME,
}
}
mock_DiscordClient.return_value.guild_member.return_value = user_info
result = self.discord_user.update_username()
self.assertFalse(result)
self.assertTrue(mock_DiscordClient.return_value.guild_member.called)
self.assertTrue(mock_guild_member.called)
@patch(MODULE_PATH + '.models.notify')
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
class TestDeleteUser(TestCase):
@patch(MODULE_PATH + '.models.notify', spec=True)
@patch(MODULE_PATH + '.models.create_bot_client', spec=True)
class TestDeleteUser(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
@@ -176,272 +142,168 @@ class TestDeleteUser(TestCase):
user=self.user, uid=TEST_USER_ID
)
def test_can_delete_user(self, mock_DiscordClient, mock_notify):
mock_DiscordClient.return_value.remove_guild_member.return_value = True
def test_can_delete_user(self, mock_create_bot_client, mock_notify):
# given
mock_create_bot_client.return_value.remove_guild_member.return_value = True
# when
result = self.discord_user.delete_user()
# then
self.assertTrue(result)
self.assertFalse(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.remove_guild_member.called)
self.assertTrue(mock_create_bot_client.return_value.remove_guild_member.called)
self.assertFalse(mock_notify.called)
def test_can_delete_user_and_notify_user(self, mock_DiscordClient, mock_notify):
mock_DiscordClient.return_value.remove_guild_member.return_value = True
def test_can_delete_user_and_notify_user(self, mock_create_bot_client, mock_notify):
# given
mock_create_bot_client.return_value.remove_guild_member.return_value = True
# when
result = self.discord_user.delete_user(notify_user=True)
# then
self.assertTrue(result)
self.assertTrue(mock_notify.called)
def test_can_delete_user_when_member_is_unknown(
self, mock_DiscordClient, mock_notify
self, mock_create_bot_client, mock_notify
):
mock_DiscordClient.return_value.remove_guild_member.return_value = None
# given
mock_create_bot_client.return_value.remove_guild_member.return_value = None
# when
result = self.discord_user.delete_user()
# then
self.assertTrue(result)
self.assertFalse(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.remove_guild_member.called)
self.assertTrue(mock_create_bot_client.return_value.remove_guild_member.called)
self.assertFalse(mock_notify.called)
def test_return_false_when_api_fails(self, mock_DiscordClient, mock_notify):
mock_DiscordClient.return_value.remove_guild_member.return_value = False
def test_return_false_when_api_fails(self, mock_create_bot_client, mock_notify):
# given
mock_create_bot_client.return_value.remove_guild_member.return_value = False
# when
result = self.discord_user.delete_user()
# then
self.assertFalse(result)
def test_dont_notify_if_user_was_already_deleted_and_return_none(
self, mock_DiscordClient, mock_notify
self, mock_create_bot_client, mock_notify
):
mock_DiscordClient.return_value.remove_guild_member.return_value = None
# given
mock_create_bot_client.return_value.remove_guild_member.return_value = None
DiscordUser.objects.get(pk=self.discord_user.pk).delete()
# when
result = self.discord_user.delete_user()
# then
self.assertIsNone(result)
self.assertFalse(
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
)
self.assertTrue(mock_DiscordClient.return_value.remove_guild_member.called)
self.assertTrue(mock_create_bot_client.return_value.remove_guild_member.called)
self.assertFalse(mock_notify.called)
def test_raise_exception_on_api_backoff(
self, mock_DiscordClient, mock_notify
self, mock_create_bot_client, mock_notify
):
mock_DiscordClient.return_value.remove_guild_member.side_effect = \
# given
mock_create_bot_client.return_value.remove_guild_member.side_effect = \
DiscordApiBackoff(999)
# when/then
with self.assertRaises(DiscordApiBackoff):
self.discord_user.delete_user()
def test_return_false_on_api_backoff_and_exception_handling_on(
self, mock_DiscordClient, mock_notify
self, mock_create_bot_client, mock_notify
):
mock_DiscordClient.return_value.remove_guild_member.side_effect = \
# given
mock_create_bot_client.return_value.remove_guild_member.side_effect = \
DiscordApiBackoff(999)
# when
result = self.discord_user.delete_user(handle_api_exceptions=True)
# then
self.assertFalse(result)
def test_raise_exception_on_http_error(
self, mock_DiscordClient, mock_notify
self, mock_create_bot_client, mock_notify
):
# given
mock_exception = HTTPError('error')
mock_exception.response = Mock()
mock_exception.response.status_code = 500
mock_DiscordClient.return_value.remove_guild_member.side_effect = \
mock_create_bot_client.return_value.remove_guild_member.side_effect = \
mock_exception
# when/then
with self.assertRaises(HTTPError):
self.discord_user.delete_user()
def test_return_false_on_http_error_and_exception_handling_on(
self, mock_DiscordClient, mock_notify
self, mock_create_bot_client, mock_notify
):
# given
mock_exception = HTTPError('error')
mock_exception.response = Mock()
mock_exception.response.status_code = 500
mock_DiscordClient.return_value.remove_guild_member.side_effect = \
mock_create_bot_client.return_value.remove_guild_member.side_effect = \
mock_exception
# when
result = self.discord_user.delete_user(handle_api_exceptions=True)
# then
self.assertFalse(result)
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
@patch(MODULE_PATH + '.models.DiscordUser.objects.user_group_names')
class TestUpdateGroups(TestCase):
@patch(MODULE_PATH + '.models.default_bot_client', spec=True)
@patch(MODULE_PATH + '.models.calculate_roles_for_user', spec=True)
class TestUpdateGroups(NoSocketsTestCase):
def setUp(self):
self.user = AuthUtils.create_user(TEST_USER_NAME)
self.discord_user = DiscordUser.objects.create(
user=self.user, uid=TEST_USER_ID
)
self.guild_roles = [ROLE_ALPHA, ROLE_BRAVO, ROLE_CHARLIE, ROLE_MIKE]
self.roles_requested = [
create_matched_role(ROLE_ALPHA), create_matched_role(ROLE_BRAVO)
]
user = AuthUtils.create_user(TEST_USER_NAME)
self.discord_user = DiscordUser.objects.create(user=user, uid=TEST_USER_ID)
def test_update_if_needed(
self,
mock_user_group_names,
mock_DiscordClient
):
roles_current = [1]
mock_user_group_names.return_value = []
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = self.roles_requested
mock_DiscordClient.return_value.guild_roles.return_value = self.guild_roles
mock_DiscordClient.return_value.guild_member.return_value = \
{'roles': roles_current}
mock_DiscordClient.return_value.modify_guild_member.return_value = True
result = self.discord_user.update_groups()
self.assertTrue(result)
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
args, kwargs = mock_DiscordClient.return_value.modify_guild_member.call_args
self.assertEqual(set(kwargs['role_ids']), {1, 2})
def test_should_update_and_preserve_managed_and_reserved_roles(
self,
mock_user_group_names,
mock_DiscordClient
def test_should_update_when_roles_have_changed(
self, mock_calculate_roles_for_user, mock_client
):
# given
roles_current = [1, 3, 4, 13]
mock_user_group_names.return_value = []
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = self.roles_requested
mock_DiscordClient.return_value.guild_roles.return_value = [
ROLE_ALPHA, ROLE_BRAVO, ROLE_CHARLIE, ROLE_MIKE, ROLE_CHARLIE_2
]
mock_DiscordClient.return_value.guild_member.return_value = {
'roles': roles_current
}
mock_DiscordClient.return_value.modify_guild_member.return_value = True
ReservedGroupName.objects.create(
name="charlie", reason="dummy", created_by="xyz"
)
mock_calculate_roles_for_user.return_value = RolesSet([create_role()]), True
mock_client.modify_guild_member.return_value = True
# when
result = self.discord_user.update_groups()
# then
self.assertTrue(result)
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
args, kwargs = mock_DiscordClient.return_value.modify_guild_member.call_args
self.assertEqual(set(kwargs['role_ids']), {1, 2, 3, 4, 13})
self.assertTrue(mock_client.modify_guild_member.called)
def test_dont_update_if_not_needed(
self,
mock_user_group_names,
mock_DiscordClient
def test_should_not_update_when_roles_have_not_changed(
self, mock_calculate_roles_for_user, mock_client
):
roles_current = [1, 2, 13]
mock_user_group_names.return_value = []
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = self.roles_requested
mock_DiscordClient.return_value.guild_roles.return_value = self.guild_roles
mock_DiscordClient.return_value.guild_member.return_value = \
{'roles': roles_current}
# given
mock_calculate_roles_for_user.return_value = RolesSet([create_role()]), False
mock_client.modify_guild_member.return_value = True
# when
result = self.discord_user.update_groups()
# then
self.assertTrue(result)
self.assertFalse(mock_DiscordClient.return_value.modify_guild_member.called)
self.assertFalse(mock_client.modify_guild_member.called)
def test_update_if_user_has_no_roles_on_discord(
self,
mock_user_group_names,
mock_DiscordClient
def test_should_not_update_when_user_not_guild_member(
self, mock_calculate_roles_for_user, mock_client
):
roles_current = []
mock_user_group_names.return_value = []
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = self.roles_requested
mock_DiscordClient.return_value.guild_roles.return_value = self.guild_roles
mock_DiscordClient.return_value.guild_member.return_value = \
{'roles': roles_current}
mock_DiscordClient.return_value.modify_guild_member.return_value = True
result = self.discord_user.update_groups()
self.assertTrue(result)
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
args, kwargs = mock_DiscordClient.return_value.modify_guild_member.call_args
self.assertEqual(set(kwargs['role_ids']), {1, 2})
def test_return_none_if_user_no_longer_a_member(
self,
mock_user_group_names,
mock_DiscordClient
):
mock_DiscordClient.return_value.guild_member.return_value = None
# given
mock_calculate_roles_for_user.return_value = RolesSet([create_role()]), None
mock_client.modify_guild_member.return_value = True
# when
result = self.discord_user.update_groups()
# then
self.assertIsNone(result)
self.assertFalse(mock_DiscordClient.return_value.modify_guild_member.called)
self.assertFalse(mock_client.modify_guild_member.called)
def test_return_false_if_api_returns_false(
self,
mock_user_group_names,
mock_DiscordClient
def test_should_return_false_when_update_failed(
self, mock_calculate_roles_for_user, mock_client
):
roles_current = [1]
mock_user_group_names.return_value = []
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = self.roles_requested
mock_DiscordClient.return_value.guild_roles.return_value = self.guild_roles
mock_DiscordClient.return_value.guild_member.return_value = \
{'roles': roles_current}
mock_DiscordClient.return_value.modify_guild_member.return_value = False
# given
mock_calculate_roles_for_user.return_value = RolesSet([create_role()]), True
mock_client.modify_guild_member.return_value = False
# when
result = self.discord_user.update_groups()
# then
self.assertFalse(result)
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
def test_raise_exception_if_member_has_unknown_roles(
self,
mock_user_group_names,
mock_DiscordClient
):
roles_current = [99]
mock_user_group_names.return_value = []
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = self.roles_requested
mock_DiscordClient.return_value.guild_roles.return_value = self.guild_roles
mock_DiscordClient.return_value.guild_member.return_value = \
{'roles': roles_current}
mock_DiscordClient.return_value.modify_guild_member.return_value = True
with self.assertRaises(RuntimeError):
self.discord_user.update_groups()
def test_refresh_guild_roles_user_roles_dont_not_match(
self,
mock_user_group_names,
mock_DiscordClient
):
def my_guild_roles(guild_id, use_cache=True):
if use_cache:
return [ROLE_ALPHA, ROLE_BRAVO, ROLE_MIKE]
else:
return [ROLE_ALPHA, ROLE_BRAVO, ROLE_CHARLIE, ROLE_MIKE]
roles_current = [3]
mock_user_group_names.return_value = []
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = self.roles_requested
mock_DiscordClient.return_value.guild_roles.side_effect = my_guild_roles
mock_DiscordClient.return_value.guild_member.return_value = \
{'roles': roles_current}
mock_DiscordClient.return_value.modify_guild_member.return_value = True
result = self.discord_user.update_groups()
self.assertTrue(result)
self.assertEqual(mock_DiscordClient.return_value.guild_roles.call_count, 2)
def test_raise_exception_if_member_info_is_invalid(
self,
mock_user_group_names,
mock_DiscordClient
):
mock_user_group_names.return_value = []
mock_DiscordClient.return_value.match_or_create_roles_from_names\
.return_value = self.roles_requested
mock_DiscordClient.return_value.guild_roles.return_value = self.guild_roles
mock_DiscordClient.return_value.guild_member.return_value = \
{'user': 'dummy'}
mock_DiscordClient.return_value.modify_guild_member.return_value = True
with self.assertRaises(RuntimeError):
self.discord_user.update_groups()
self.assertTrue(mock_client.modify_guild_member.called)

View File

@@ -3,18 +3,18 @@ from unittest.mock import MagicMock, patch
from celery.exceptions import Retry
from requests.exceptions import HTTPError
from django.test import TestCase
from django.contrib.auth.models import Group
from django.test.utils import override_settings
from allianceauth.tests.auth_utils import AuthUtils
from allianceauth.utils.testing import NoSocketsTestCase
from . import TEST_USER_NAME, TEST_USER_ID, TEST_MAIN_NAME, TEST_MAIN_ID
from ..models import DiscordUser
from ..discord_client import DiscordApiBackoff
from .. import tasks
from ..discord_client import DiscordApiBackoff
from ..discord_client.tests.factories import TEST_USER_ID, TEST_USER_NAME
from ..models import DiscordUser
from ..utils import set_logger_to_file
from . import TEST_MAIN_ID, TEST_MAIN_NAME
MODULE_PATH = 'allianceauth.services.modules.discord.tasks'
logger = set_logger_to_file(MODULE_PATH, __file__)
@@ -22,7 +22,7 @@ logger = set_logger_to_file(MODULE_PATH, __file__)
@patch(MODULE_PATH + '.DiscordUser.update_groups')
@patch(MODULE_PATH + ".logger")
class TestUpdateGroups(TestCase):
class TestUpdateGroups(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
@@ -110,7 +110,7 @@ class TestUpdateGroups(TestCase):
@patch(MODULE_PATH + '.DiscordUser.update_nickname')
class TestUpdateNickname(TestCase):
class TestUpdateNickname(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
@@ -163,7 +163,7 @@ class TestUpdateNickname(TestCase):
@patch(MODULE_PATH + '.DiscordUser.update_username')
class TestUpdateUsername(TestCase):
class TestUpdateUsername(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
@@ -179,7 +179,7 @@ class TestUpdateUsername(TestCase):
@patch(MODULE_PATH + '.DiscordUser.delete_user')
class TestDeleteUser(TestCase):
class TestDeleteUser(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
@@ -213,7 +213,7 @@ class TestDeleteUser(TestCase):
@patch(MODULE_PATH + '.DiscordUser.update_groups')
class TestTaskPerformUserAction(TestCase):
class TestTaskPerformUserAction(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
@@ -236,7 +236,7 @@ class TestTaskPerformUserAction(TestCase):
@patch(MODULE_PATH + '.DiscordUser.objects.server_name')
@patch(MODULE_PATH + ".logger")
class TestTaskUpdateServername(TestCase):
class TestTaskUpdateServername(NoSocketsTestCase):
def test_normal(self, mock_logger, mock_server_name):
tasks.update_servername()
@@ -281,7 +281,7 @@ class TestTaskUpdateServername(TestCase):
@patch(MODULE_PATH + '.DiscordUser.objects.server_name')
class TestTaskPerformUsersAction(TestCase):
class TestTaskPerformUsersAction(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
@@ -300,8 +300,8 @@ class TestTaskPerformUsersAction(TestCase):
tasks._task_perform_users_action(mock_task, 'server_name')
@override_settings(CELERY_ALWAYS_EAGER=True)
class TestBulkTasks(TestCase):
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
class TestBulkTasks(NoSocketsTestCase):
@classmethod
def setUpClass(cls):

View File

@@ -1,4 +1,5 @@
from unittest.mock import Mock, patch
from django.test import TestCase
from ..utils import clean_setting

View File

@@ -1,28 +1,28 @@
from unittest.mock import patch
from django.contrib.auth.models import User
from django.test import TestCase, RequestFactory
from django.test import RequestFactory
from django.urls import reverse
from allianceauth.tests.auth_utils import AuthUtils
from allianceauth.utils.testing import NoSocketsTestCase
from . import MODULE_PATH, add_permissions_to_members, TEST_USER_NAME, TEST_USER_ID
from ..discord_client import DiscordClient
from ..discord_client.tests.factories import TEST_USER_ID, TEST_USER_NAME
from ..models import DiscordUser
from ..utils import set_logger_to_file
from ..views import (
discord_callback,
reset_discord,
activate_discord,
deactivate_discord,
discord_add_bot,
activate_discord
discord_callback,
reset_discord,
)
from . import MODULE_PATH, add_permissions_to_members
logger = set_logger_to_file(MODULE_PATH + '.views', __file__)
class SetupClassMixin(TestCase):
class SetupClassMixin(NoSocketsTestCase):
@classmethod
def setUpClass(cls):
@@ -33,7 +33,7 @@ class SetupClassMixin(TestCase):
cls.services_url = reverse('services:services')
class TestActivateDiscord(SetupClassMixin, TestCase):
class TestActivateDiscord(SetupClassMixin, NoSocketsTestCase):
@patch(MODULE_PATH + '.views.DiscordUser.objects.generate_oauth_redirect_url')
def test_redirects_to_correct_url(self, mock_generate_oauth_redirect_url):
@@ -47,31 +47,37 @@ class TestActivateDiscord(SetupClassMixin, TestCase):
@patch(MODULE_PATH + '.views.messages')
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
class TestDeactivateDiscord(SetupClassMixin, TestCase):
@patch(MODULE_PATH + '.models.create_bot_client')
class TestDeactivateDiscord(SetupClassMixin, NoSocketsTestCase):
def setUp(self):
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
def test_when_successful_show_success_message(
self, mock_DiscordClient, mock_messages
self, mock_create_bot_client, mock_messages
):
mock_DiscordClient.return_value.remove_guild_member.return_value = True
# given
mock_create_bot_client.return_value.remove_guild_member.return_value = True
request = self.factory.get(reverse('discord:deactivate'))
request.user = self.user
# when
response = deactivate_discord(request)
# then
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, self.services_url)
self.assertTrue(mock_messages.success.called)
self.assertFalse(mock_messages.error.called)
def test_when_unsuccessful_show_error_message(
self, mock_DiscordClient, mock_messages
self, mock_create_bot_client, mock_messages
):
mock_DiscordClient.return_value.remove_guild_member.return_value = False
# given
mock_create_bot_client.return_value.remove_guild_member.return_value = False
request = self.factory.get(reverse('discord:deactivate'))
request.user = self.user
# when
response = deactivate_discord(request)
# then
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, self.services_url)
self.assertFalse(mock_messages.success.called)
@@ -79,30 +85,36 @@ class TestDeactivateDiscord(SetupClassMixin, TestCase):
@patch(MODULE_PATH + '.views.messages')
@patch(MODULE_PATH + '.managers.DiscordClient')
class TestResetDiscord(SetupClassMixin, TestCase):
@patch(MODULE_PATH + '.models.create_bot_client')
class TestResetDiscord(SetupClassMixin, NoSocketsTestCase):
def setUp(self):
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
def test_when_successful_redirect_to_activate(
self, mock_DiscordClient, mock_messages
self, mock_create_bot_client, mock_messages
):
mock_DiscordClient.return_value.remove_guild_member.return_value = True
# given
mock_create_bot_client.return_value.remove_guild_member.return_value = True
request = self.factory.get(reverse('discord:reset'))
request.user = self.user
# when
response = reset_discord(request)
# then
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, reverse("discord:activate"))
self.assertFalse(mock_messages.error.called)
def test_when_unsuccessful_message_error_and_redirect_to_service(
self, mock_DiscordClient, mock_messages
self, mock_create_bot_client, mock_messages
):
mock_DiscordClient.return_value.remove_guild_member.return_value = False
# given
mock_create_bot_client.return_value.remove_guild_member.return_value = False
request = self.factory.get(reverse('discord:reset'))
request.user = self.user
# when
response = reset_discord(request)
# then
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, self.services_url)
self.assertTrue(mock_messages.error.called)
@@ -110,7 +122,7 @@ class TestResetDiscord(SetupClassMixin, TestCase):
@patch(MODULE_PATH + '.views.messages')
@patch(MODULE_PATH + '.views.DiscordUser.objects.add_user')
class TestDiscordCallback(SetupClassMixin, TestCase):
class TestDiscordCallback(SetupClassMixin, NoSocketsTestCase):
def setUp(self):
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
@@ -155,7 +167,7 @@ class TestDiscordCallback(SetupClassMixin, TestCase):
@patch(MODULE_PATH + '.views.DiscordUser.objects.generate_bot_add_url')
class TestDiscordAddBot(TestCase):
class TestDiscordAddBot(NoSocketsTestCase):
def test_add_bot(self, mock_generate_bot_add_url):
bot_url = 'https://www.example.com/bot'

View File

@@ -1,7 +1,8 @@
from django.contrib import admin
from .models import AuthTS, Teamspeak3User, StateGroup
from django.contrib.auth.models import Group
from .models import AuthTS, Teamspeak3User, StateGroup, TSgroup
from ...admin import ServicesUserAdmin
from allianceauth.groupmanagement.models import ReservedGroupName
@admin.register(Teamspeak3User)
@@ -25,6 +26,16 @@ class AuthTSgroupAdmin(admin.ModelAdmin):
fields = ('auth_group', 'ts_group')
filter_horizontal = ('ts_group',)
def formfield_for_foreignkey(self, db_field, request, **kwargs):
if db_field.name == 'auth_group':
kwargs['queryset'] = Group.objects.exclude(name__in=ReservedGroupName.objects.values_list('name', flat=True))
return super().formfield_for_foreignkey(db_field, request, **kwargs)
def formfield_for_manytomany(self, db_field, request, **kwargs):
if db_field.name == 'ts_group':
kwargs['queryset'] = TSgroup.objects.exclude(ts_group_name__in=ReservedGroupName.objects.values_list('name', flat=True))
return super().formfield_for_manytomany(db_field, request, **kwargs)
def _ts_group(self, obj):
return [x for x in obj.ts_group.all().order_by('ts_group_id')]

View File

@@ -4,6 +4,7 @@ from django.conf import settings
from .util.ts3 import TS3Server, TeamspeakError
from .models import TSgroup
from allianceauth.groupmanagement.models import ReservedGroupName
logger = logging.getLogger(__name__)
@@ -156,32 +157,25 @@ class Teamspeak3Manager:
logger.info(f"Removed user id {uid} from group id {groupid} on TS3 server.")
def _sync_ts_group_db(self):
logger.debug("_sync_ts_group_db function called.")
try:
remote_groups = self._group_list()
local_groups = TSgroup.objects.all()
logger.debug("Comparing remote groups to TSgroup objects: %s" % local_groups)
for key in remote_groups:
logger.debug(f"Typecasting remote_group value at position {key} to int: {remote_groups[key]}")
remote_groups[key] = int(remote_groups[key])
managed_groups = {g:int(remote_groups[g]) for g in remote_groups if g in set(remote_groups.keys()) - set(ReservedGroupName.objects.values_list('name', flat=True))}
remove = TSgroup.objects.exclude(ts_group_id__in=managed_groups.values())
if remove:
logger.debug(f"Deleting {remove.count()} TSgroup models: not found on server, or reserved name.")
remove.delete()
add = {g:managed_groups[g] for g in managed_groups if managed_groups[g] in set(managed_groups.values()) - set(TSgroup.objects.values_list("ts_group_id", flat=True))}
if add:
logger.debug(f"Adding {len(add)} new TSgroup models.")
models = [TSgroup(ts_group_name=name, ts_group_id=add[name]) for name in add]
TSgroup.objects.bulk_create(models)
for group in local_groups:
logger.debug("Checking local group %s" % group)
if group.ts_group_id not in remote_groups.values():
logger.debug(
f"Local group id {group.ts_group_id} not found on server. Deleting model {group}")
TSgroup.objects.filter(ts_group_id=group.ts_group_id).delete()
for key in remote_groups:
g = TSgroup(ts_group_id=remote_groups[key], ts_group_name=key)
q = TSgroup.objects.filter(ts_group_id=g.ts_group_id)
if not q:
logger.debug("Local group does not exist for TS group {}. Creating TSgroup model {}".format(
remote_groups[key], g))
g.save()
except TeamspeakError as e:
logger.error("Error occured while syncing TS group db: %s" % str(e))
except:
logger.exception("An unhandled exception has occured while syncing TS groups.")
logger.error(f"Error occurred while syncing TS group db: {str(e)}")
except Exception:
logger.exception(f"An unhandled exception has occurred while syncing TS groups.")
def add_user(self, user, fmt_name):
username_clean = self.__santatize_username(fmt_name[:30])
@@ -240,7 +234,7 @@ class Teamspeak3Manager:
logger.exception(f"Failed to delete user id {uid} from TS3 - received response {ret}")
return False
else:
logger.warn("User with id %s not found on TS3 server. Assuming succesful deletion." % uid)
logger.warning("User with id %s not found on TS3 server. Assuming succesful deletion." % uid)
return True
def check_user_exists(self, uid):
@@ -270,7 +264,8 @@ class Teamspeak3Manager:
addgroups.append(ts_groups[ts_group_key])
for user_ts_group_key in user_ts_groups:
if user_ts_groups[user_ts_group_key] not in ts_groups.values():
remgroups.append(user_ts_groups[user_ts_group_key])
if not ReservedGroupName.objects.filter(name=user_ts_group_key).exists():
remgroups.append(user_ts_groups[user_ts_group_key])
for g in addgroups:
logger.info(f"Adding Teamspeak user {userid} into group {g}")

View File

@@ -5,16 +5,18 @@ from django import urls
from django.contrib.auth.models import User, Group, Permission
from django.core.exceptions import ObjectDoesNotExist
from django.db.models import signals
from django.contrib.admin import AdminSite
from allianceauth.tests.auth_utils import AuthUtils
from .auth_hooks import Teamspeak3Service
from .models import Teamspeak3User, AuthTS, TSgroup, StateGroup
from .tasks import Teamspeak3Tasks
from .signals import m2m_changed_authts_group, post_save_authts, post_delete_authts
from .admin import AuthTSgroupAdmin
from .manager import Teamspeak3Manager
from .util.ts3 import TeamspeakError
from allianceauth.authentication.models import State
from allianceauth.groupmanagement.models import ReservedGroupName
MODULE_PATH = 'allianceauth.services.modules.teamspeak3'
DEFAULT_AUTH_GROUP = 'Member'
@@ -315,6 +317,9 @@ class Teamspeak3SignalsTestCase(TestCase):
class Teamspeak3ManagerTestCase(TestCase):
@classmethod
def setUpTestData(cls):
cls.reserved = ReservedGroupName.objects.create(name='reserved', reason='tests', created_by='Bob, praise be!')
@staticmethod
def my_side_effect(*args, **kwargs):
@@ -334,8 +339,135 @@ class Teamspeak3ManagerTestCase(TestCase):
manager._server = server
# create test data
user = User.objects.create_user("dummy")
user.profile.state = State.objects.filter(name="Member").first()
user = AuthUtils.create_user("dummy")
AuthUtils.assign_state(user, AuthUtils.get_member_state())
# perform test
manager.add_user(user, "Dummy User")
@mock.patch.object(Teamspeak3Manager, '_get_userid')
@mock.patch.object(Teamspeak3Manager, '_user_group_list')
@mock.patch.object(Teamspeak3Manager, '_add_user_to_group')
@mock.patch.object(Teamspeak3Manager, '_remove_user_from_group')
def test_update_groups_add(self, remove, add, groups, userid):
"""Add to one group"""
userid.return_value = 1
groups.return_value = {'test': 1}
Teamspeak3Manager().update_groups(1, {'test': 1, 'dummy': 2})
self.assertEqual(add.call_count, 1)
self.assertEqual(remove.call_count, 0)
self.assertEqual(add.call_args[0][1], 2)
@mock.patch.object(Teamspeak3Manager, '_get_userid')
@mock.patch.object(Teamspeak3Manager, '_user_group_list')
@mock.patch.object(Teamspeak3Manager, '_add_user_to_group')
@mock.patch.object(Teamspeak3Manager, '_remove_user_from_group')
def test_update_groups_remove(self, remove, add, groups, userid):
"""Remove from one group"""
userid.return_value = 1
groups.return_value = {'test': '1', 'dummy': '2'}
Teamspeak3Manager().update_groups(1, {'test': 1})
self.assertEqual(add.call_count, 0)
self.assertEqual(remove.call_count, 1)
self.assertEqual(remove.call_args[0][1], 2)
@mock.patch.object(Teamspeak3Manager, '_get_userid')
@mock.patch.object(Teamspeak3Manager, '_user_group_list')
@mock.patch.object(Teamspeak3Manager, '_add_user_to_group')
@mock.patch.object(Teamspeak3Manager, '_remove_user_from_group')
def test_update_groups_remove_reserved(self, remove, add, groups, userid):
"""Remove from one group, but do not touch reserved group"""
userid.return_value = 1
groups.return_value = {'test': 1, 'dummy': 2, self.reserved.name: 3}
Teamspeak3Manager().update_groups(1, {'test': 1})
self.assertEqual(add.call_count, 0)
self.assertEqual(remove.call_count, 1)
self.assertEqual(remove.call_args[0][1], 2)
@mock.patch.object(Teamspeak3Manager, '_group_list')
def test_sync_group_db_create(self, group_list):
"""Populate the list of all TSgroups"""
group_list.return_value = {'allowed':'1', 'also allowed':'2'}
Teamspeak3Manager()._sync_ts_group_db()
self.assertEqual(TSgroup.objects.all().count(), 2)
@mock.patch.object(Teamspeak3Manager, '_group_list')
def test_sync_group_db_delete(self, group_list):
"""Populate the list of all TSgroups, and delete one which no longer exists"""
TSgroup.objects.create(ts_group_name='deleted', ts_group_id=3)
group_list.return_value = {'allowed': '1'}
Teamspeak3Manager()._sync_ts_group_db()
self.assertEqual(TSgroup.objects.all().count(), 1)
self.assertFalse(TSgroup.objects.filter(ts_group_name='deleted').exists())
@mock.patch.object(Teamspeak3Manager, '_group_list')
def test_sync_group_db_dont_create_reserved(self, group_list):
"""Populate the list of all TSgroups, ignoring a reserved group name"""
group_list.return_value = {'allowed': '1', 'reserved': '4'}
Teamspeak3Manager()._sync_ts_group_db()
self.assertEqual(TSgroup.objects.all().count(), 1)
self.assertFalse(TSgroup.objects.filter(ts_group_name='reserved').exists())
@mock.patch.object(Teamspeak3Manager, '_group_list')
def test_sync_group_db_delete_reserved(self, group_list):
"""Populate the list of all TSgroups, deleting the TSgroup model for one which has become reserved"""
TSgroup.objects.create(ts_group_name='reserved', ts_group_id=4)
group_list.return_value = {'allowed': '1', 'reserved': '4'}
Teamspeak3Manager()._sync_ts_group_db()
self.assertEqual(TSgroup.objects.all().count(), 1)
self.assertFalse(TSgroup.objects.filter(ts_group_name='reserved').exists())
@mock.patch.object(Teamspeak3Manager, '_group_list')
def test_sync_group_db_partial_addition(self, group_list):
"""Some TSgroups already exist in database, add new ones"""
TSgroup.objects.create(ts_group_name='allowed', ts_group_id=1)
group_list.return_value = {'allowed': '1', 'also allowed': '2'}
Teamspeak3Manager()._sync_ts_group_db()
self.assertEqual(TSgroup.objects.all().count(), 2)
@mock.patch.object(Teamspeak3Manager, '_group_list')
def test_sync_group_db_partial_removal(self, group_list):
"""One TSgroup has been deleted on server, so remove its model"""
TSgroup.objects.create(ts_group_name='allowed', ts_group_id=1)
TSgroup.objects.create(ts_group_name='also allowed', ts_group_id=2)
group_list.return_value = {'allowed': '1'}
Teamspeak3Manager()._sync_ts_group_db()
self.assertEqual(TSgroup.objects.all().count(), 1)
class MockRequest:
pass
class MockSuperUser:
def has_perm(self, perm, obj=None):
return True
request = MockRequest()
request.user = MockSuperUser()
class Teamspeak3AdminTestCase(TestCase):
@classmethod
def setUpTestData(cls):
cls.site = AdminSite()
cls.admin = AuthTSgroupAdmin(AuthTS, cls.site)
cls.group = Group.objects.create(name='test')
cls.ts_group = TSgroup.objects.create(ts_group_name='test')
def test_field_queryset_no_reserved_names(self):
"""Ensure all groups are listed when no reserved names"""
form = self.admin.get_form(request)
self.assertQuerysetEqual(form.base_fields['auth_group']._get_queryset(), Group.objects.all())
self.assertQuerysetEqual(form.base_fields['ts_group']._get_queryset(), TSgroup.objects.all())
def test_field_queryset_reserved_names(self):
"""Ensure reserved group names are filtered out"""
ReservedGroupName.objects.bulk_create([ReservedGroupName(name='test', reason='tests', created_by='Bob')])
form = self.admin.get_form(request)
self.assertQuerysetEqual(form.base_fields['auth_group']._get_queryset(), Group.objects.none())
self.assertQuerysetEqual(form.base_fields['ts_group']._get_queryset(), TSgroup.objects.none())

View File

@@ -1,4 +1,5 @@
import logging
from functools import partial
from django.contrib.auth.models import User, Group, Permission
from django.core.exceptions import ObjectDoesNotExist
@@ -8,7 +9,7 @@ from django.db.models.signals import pre_delete
from django.db.models.signals import pre_save
from django.dispatch import receiver
from .hooks import ServicesHook
from .tasks import disable_user
from .tasks import disable_user, update_groups_for_user
from allianceauth.authentication.models import State, UserProfile
from allianceauth.authentication.signals import state_changed
@@ -19,21 +20,27 @@ logger = logging.getLogger(__name__)
@receiver(m2m_changed, sender=User.groups.through)
def m2m_changed_user_groups(sender, instance, action, *args, **kwargs):
logger.debug(f"Received m2m_changed from {instance} groups with action {action}")
def trigger_service_group_update():
logger.debug("Triggering service group update for %s" % instance)
# Iterate through Service hooks
for svc in ServicesHook.get_services():
try:
svc.validate_user(instance)
svc.update_groups(instance)
except:
logger.exception(f'Exception running update_groups for services module {svc} on user {instance}')
if instance.pk and (action == "post_add" or action == "post_remove" or action == "post_clear"):
logger.debug("Waiting for commit to trigger service group update for %s" % instance)
transaction.on_commit(trigger_service_group_update)
logger.debug(
"%s: Received m2m_changed from groups with action %s", instance, action
)
if instance.pk and (
action == "post_add" or action == "post_remove" or action == "post_clear"
):
if isinstance(instance, User):
logger.debug(
"Waiting for commit to trigger service group update for %s", instance
)
transaction.on_commit(partial(update_groups_for_user.delay, instance.pk))
elif (
isinstance(instance, Group)
and kwargs.get("model") is User
and "pk_set" in kwargs
):
for user_pk in kwargs["pk_set"]:
logger.debug(
"%s: Waiting for commit to trigger service group update for user", user_pk
)
transaction.on_commit(partial(update_groups_for_user.delay, user_pk))
@receiver(m2m_changed, sender=User.user_permissions.through)

View File

@@ -47,3 +47,20 @@ def disable_user(user):
for svc in ServicesHook.get_services():
if svc.service_active_for_user(user):
svc.delete_user(user)
@shared_task
def update_groups_for_user(user_pk: int) -> None:
"""Update groups for all services registered to a user."""
user = User.objects.get(pk=user_pk)
logger.debug("%s: Triggering service group update for user", user)
for svc in ServicesHook.get_services():
try:
svc.validate_user(user)
svc.update_groups(user)
except Exception:
logger.exception(
'Exception running update_groups for services module %s on user %s',
svc,
user
)

View File

@@ -1,7 +1,7 @@
from copy import deepcopy
from unittest import mock
from django.test import TestCase
from django.test import override_settings, TestCase, TransactionTestCase
from django.contrib.auth.models import Group, Permission
from allianceauth.authentication.models import State
@@ -9,6 +9,9 @@ from allianceauth.eveonline.models import EveCharacter
from allianceauth.tests.auth_utils import AuthUtils
MODULE_PATH = 'allianceauth.services.signals'
class ServicesSignalsTestCase(TestCase):
def setUp(self):
self.member = AuthUtils.create_user('auth_member', disconnect_signals=True)
@@ -17,17 +20,12 @@ class ServicesSignalsTestCase(TestCase):
)
self.none_user = AuthUtils.create_user('none_user', disconnect_signals=True)
@mock.patch('allianceauth.services.signals.transaction')
@mock.patch('allianceauth.services.signals.ServicesHook')
def test_m2m_changed_user_groups(self, services_hook, transaction):
@mock.patch(MODULE_PATH + '.transaction', spec=True)
@mock.patch(MODULE_PATH + '.update_groups_for_user', spec=True)
def test_m2m_changed_user_groups(self, update_groups_for_user, transaction):
"""
Test that update_groups hook function is called on user groups change
"""
svc = mock.Mock()
svc.update_groups.return_value = None
svc.validate_user.return_value = None
services_hook.get_services.return_value = [svc]
# Overload transaction.on_commit so everything happens synchronously
transaction.on_commit = lambda fn: fn()
@@ -39,17 +37,11 @@ class ServicesSignalsTestCase(TestCase):
self.member.save()
# Assert
self.assertTrue(services_hook.get_services.called)
self.assertTrue(update_groups_for_user.delay.called)
args, _ = update_groups_for_user.delay.call_args
self.assertEqual(self.member.pk, args[0])
self.assertTrue(svc.update_groups.called)
args, kwargs = svc.update_groups.call_args
self.assertEqual(self.member, args[0])
self.assertTrue(svc.validate_user.called)
args, kwargs = svc.validate_user.call_args
self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.disable_user')
@mock.patch(MODULE_PATH + '.disable_user')
def test_pre_delete_user(self, disable_user):
"""
Test that disable_member is called when a user is deleted
@@ -60,7 +52,7 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = disable_user.call_args
self.assertEqual(self.none_user, args[0])
@mock.patch('allianceauth.services.signals.disable_user')
@mock.patch(MODULE_PATH + '.disable_user')
def test_pre_save_user_inactivation(self, disable_user):
"""
Test a user set inactive has disable_member called
@@ -72,7 +64,7 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = disable_user.call_args
self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.disable_user')
@mock.patch(MODULE_PATH + '.disable_user')
def test_disable_services_on_loss_of_main_character(self, disable_user):
"""
Test a user set inactive has disable_member called
@@ -84,8 +76,8 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = disable_user.call_args
self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.transaction')
@mock.patch('allianceauth.services.signals.ServicesHook')
@mock.patch(MODULE_PATH + '.transaction')
@mock.patch(MODULE_PATH + '.ServicesHook')
def test_m2m_changed_group_permissions(self, services_hook, transaction):
from django.contrib.contenttypes.models import ContentType
svc = mock.Mock()
@@ -116,8 +108,8 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = svc.validate_user.call_args
self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.transaction')
@mock.patch('allianceauth.services.signals.ServicesHook')
@mock.patch(MODULE_PATH + '.transaction')
@mock.patch(MODULE_PATH + '.ServicesHook')
def test_m2m_changed_user_permissions(self, services_hook, transaction):
from django.contrib.contenttypes.models import ContentType
svc = mock.Mock()
@@ -145,8 +137,8 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = svc.validate_user.call_args
self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.transaction')
@mock.patch('allianceauth.services.signals.ServicesHook')
@mock.patch(MODULE_PATH + '.transaction')
@mock.patch(MODULE_PATH + '.ServicesHook')
def test_m2m_changed_user_state_permissions(self, services_hook, transaction):
from django.contrib.contenttypes.models import ContentType
svc = mock.Mock()
@@ -180,7 +172,7 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = svc.validate_user.call_args
self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.ServicesHook')
@mock.patch(MODULE_PATH + '.ServicesHook')
def test_state_changed_services_validation_and_groups_update(self, services_hook):
"""Test a user changing state has service accounts validated and groups updated
"""
@@ -206,8 +198,7 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = svc.update_groups.call_args
self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.ServicesHook')
@mock.patch(MODULE_PATH + '.ServicesHook')
def test_state_changed_services_validation_and_groups_update_1(self, services_hook):
"""Test a user changing main has service accounts validated and sync updated
"""
@@ -238,7 +229,7 @@ class ServicesSignalsTestCase(TestCase):
args, kwargs = svc.sync_nickname.call_args
self.assertEqual(self.member, args[0])
@mock.patch('allianceauth.services.signals.ServicesHook')
@mock.patch(MODULE_PATH + '.ServicesHook')
def test_state_changed_services_validation_and_groups_update_2(self, services_hook):
"""Test a user changing main has service does not have accounts validated
and sync updated if the new main is equal to the old main
@@ -260,3 +251,71 @@ class ServicesSignalsTestCase(TestCase):
self.assertFalse(services_hook.get_services.called)
self.assertFalse(svc.validate_user.called)
self.assertFalse(svc.sync_nickname.called)
@mock.patch(
"allianceauth.services.modules.mumble.auth_hooks.MumbleService.update_groups"
)
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
class TestUserGroupBulkUpdate(TransactionTestCase):
def test_should_run_user_service_check_when_group_added_to_user(
self, mock_update_groups
):
# given
user = AuthUtils.create_user("Bruce Wayne")
AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
group = Group.objects.create(name="Group")
mock_update_groups.reset_mock()
# when
user.groups.add(group)
# then
users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list}
self.assertSetEqual(users_updated, {user})
def test_should_run_user_service_check_when_multiple_groups_are_added_to_user(
self, mock_update_groups
):
# given
user = AuthUtils.create_user("Bruce Wayne")
AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
group_1 = Group.objects.create(name="Group 1")
group_2 = Group.objects.create(name="Group 2")
mock_update_groups.reset_mock()
# when
user.groups.add(group_1, group_2)
# then
users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list}
self.assertSetEqual(users_updated, {user})
def test_should_run_user_service_check_when_user_added_to_group(
self, mock_update_groups
):
# given
user = AuthUtils.create_user("Bruce Wayne")
AuthUtils.add_main_character_2(user, "Bruce Wayne", 1001)
group = Group.objects.create(name="Group")
mock_update_groups.reset_mock()
# when
group.user_set.add(user)
# then
users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list}
self.assertSetEqual(users_updated, {user})
def test_should_run_user_service_check_when_multiple_users_are_added_to_group(
self, mock_update_groups
):
# given
user_1 = AuthUtils.create_user("Bruce Wayne")
AuthUtils.add_main_character_2(user_1, "Bruce Wayne", 1001)
user_2 = AuthUtils.create_user("Peter Parker")
AuthUtils.add_main_character_2(user_2, "Peter Parker", 1002)
user_3 = AuthUtils.create_user("Lex Luthor")
AuthUtils.add_main_character_2(user_3, "Lex Luthor", 1011)
group = Group.objects.create(name="Group")
user_1.groups.add(group)
mock_update_groups.reset_mock()
# when
group.user_set.add(user_2, user_3)
# then
users_updated = {obj[0][0] for obj in mock_update_groups.call_args_list}
self.assertSetEqual(users_updated, {user_2, user_3})

View File

@@ -3,32 +3,50 @@ from unittest import mock
from celery_once import AlreadyQueued
from django.core.cache import cache
from django.test import TestCase
from django.test import override_settings, TestCase
from allianceauth.tests.auth_utils import AuthUtils
from allianceauth.services.tasks import validate_services
from allianceauth.services.tasks import validate_services, update_groups_for_user
from ..tasks import DjangoBackend
@override_settings(CELERY_ALWAYS_EAGER=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True)
class ServicesTasksTestCase(TestCase):
def setUp(self):
self.member = AuthUtils.create_user('auth_member')
@mock.patch('allianceauth.services.tasks.ServicesHook')
def test_validate_services(self, services_hook):
# given
svc = mock.Mock()
svc.validate_user.return_value = None
services_hook.get_services.return_value = [svc]
# when
validate_services.delay(self.member.pk)
# then
self.assertTrue(services_hook.get_services.called)
self.assertTrue(svc.validate_user.called)
args, kwargs = svc.validate_user.call_args
args, _ = svc.validate_user.call_args
self.assertEqual(self.member, args[0]) # Assert correct user is passed to service hook function
@mock.patch('allianceauth.services.tasks.ServicesHook')
def test_update_groups_for_user(self, services_hook):
# given
svc = mock.Mock()
svc.validate_user.return_value = None
services_hook.get_services.return_value = [svc]
# when
update_groups_for_user.delay(self.member.pk)
# then
self.assertTrue(services_hook.get_services.called)
self.assertTrue(svc.validate_user.called)
args, _ = svc.validate_user.call_args
self.assertEqual(self.member, args[0]) # Assert correct user
self.assertTrue(svc.update_groups.called)
args, _ = svc.update_groups.call_args
self.assertEqual(self.member, args[0]) # Assert correct user
class TestDjangoBackend(TestCase):

View File

@@ -267,7 +267,9 @@ ESC to cancel{% endblocktrans %}"id="blah"></i></th>
"targets": [4, 5],
"type": "num"
}
]
],
"stateSave": true,
"stateDuration": 0
});
// tooltip

View File

@@ -95,6 +95,11 @@ ul.list-group.list-group-horizontal > li.list-group-item {
.table-aa > tbody > tr:last-child {
border-bottom: none;
}
.task-status-progress-bar {
font-size: 15px!important;
line-height: normal!important;
}
}
/* highlight active menu items

View File

@@ -1,58 +1,20 @@
$(document).ready(function () {
'use strict';
/**
* check time
* @param i
* @returns {string}
*/
let checkTime = function (i) {
if (i < 10) {
i = '0' + i;
}
return i;
};
/**
* render a JS clock for Eve Time
* @param element
* @param utcOffset
*/
let renderClock = function (element, utcOffset) {
let today = new Date();
let h = today.getUTCHours();
let m = today.getUTCMinutes();
h = h + utcOffset;
if (h > 24) {
h = h - 24;
}
if (h < 0) {
h = h + 24;
}
h = checkTime(h);
m = checkTime(m);
const renderClock = function (element) {
const datetimeNow = new Date();
const h = String(datetimeNow.getUTCHours()).padStart(2, '0');
const m = String(datetimeNow.getUTCMinutes()).padStart(2, '0');
element.html(h + ':' + m);
setTimeout(function () {
renderClock(element, 0);
}, 500);
};
/**
* functions that need to be executed on load
*/
let init = function () {
renderClock($('.eve-time-wrapper .eve-time-clock'), 0);
};
/**
* start the show
*/
init();
// Start the Eve time clock in the top menu bar
setInterval(function () {
renderClock($('.eve-time-wrapper .eve-time-clock'));
}, 500);
});

View File

@@ -0,0 +1,12 @@
{% load humanize %}
{% load admin_status %}
<div
class="progress-bar progress-bar-{{ level }} task-status-progress-bar"
role="progressbar"
aria-valuenow="{% decimal_widthratio tasks_count tasks_total 100 %}"
aria-valuemin="0"
aria-valuemax="100"
style="width: {% decimal_widthratio tasks_count tasks_total 100 %}%;">
<p style="margin-top:5px;">{% widthratio tasks_count tasks_total 100 %}%</p>
</div>

View File

@@ -1,4 +1,6 @@
{% load i18n %}
{% load humanize %}
<div class="col-sm-12">
<div class="row vertical-flexbox-row2">
<div class="col-sm-6">
@@ -75,29 +77,25 @@
<div class="panel panel-primary" style="height:50%;">
<div class="panel-heading text-center"><h3 class="panel-title">{% translate "Task Queue" %}</h3></div>
<div class="panel-body flex-center-horizontal">
<div class="progress" style="height: 21px;">
<div class="progress-bar
{% if task_queue_length > 500 %}
progress-bar-danger
{% elif task_queue_length > 100 %}
progress-bar-warning
{% else %}
progress-bar-success
{% endif %}
" role="progressbar" aria-valuenow="{% widthratio task_queue_length 500 100 %}"
aria-valuemin="0" aria-valuemax="100"
style="width: {% widthratio task_queue_length 500 100 %}%;">
</div>
<p>
{% blocktranslate with total=tasks_total|intcomma latest=earliest_task|timesince|default:"?" %}
Status of {{ total }} processed tasks • last {{ latest }}
{% endblocktranslate %}
</p>
<div
class="progress"
style="height: 21px;"
title="{{ tasks_succeeded|intcomma }} succeeded, {{ tasks_retried|intcomma }} retried, {{ tasks_failed|intcomma }} failed"
>
{% include "allianceauth/admin-status/celery_bar_partial.html" with label="suceeded" level="success" tasks_count=tasks_succeeded %}
{% include "allianceauth/admin-status/celery_bar_partial.html" with label="retried" level="info" tasks_count=tasks_retried %}
{% include "allianceauth/admin-status/celery_bar_partial.html" with label="failed" level="danger" tasks_count=tasks_failed %}
</div>
{% if task_queue_length < 0 %}
{% translate "Error retrieving task queue length" %}
{% else %}
{% blocktrans trimmed count tasks=task_queue_length %}
{{ tasks }} task
{% plural %}
{{ tasks }} tasks
{% endblocktrans %}
{% endif %}
<p>
{% blocktranslate with queue_length=task_queue_length|default_if_none:"?"|intcomma %}
{{ queue_length }} queued tasks
{% endblocktranslate %}
</p>
</div>
</div>
</div>

View File

@@ -0,0 +1,3 @@
{% load static %}
<script src="{% static 'js/eve-time.js' %}"></script>

View File

@@ -0,0 +1,3 @@
{% load static %}
<script type="application/javascript" src="{% static 'js/filterDropDown/filterDropDown.min.js' %}"></script>

View File

@@ -0,0 +1,3 @@
{% load static %}
<script src="{% static 'js/refresh_notifications.js' %}"></script>

View File

@@ -0,0 +1,3 @@
{% load static %}
<script src="{% static 'js/timers.js' %}"></script>

Some files were not shown because too many files have changed in this diff Show More