mirror of
https://gitlab.com/allianceauth/allianceauth.git
synced 2026-02-11 01:26:22 +01:00
Discord service major overhaul
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
from django import forms
|
||||
from django.contrib import admin
|
||||
from django.db.models.functions import Lower
|
||||
from django.urls import reverse
|
||||
from django.utils.html import format_html
|
||||
|
||||
from allianceauth import hooks
|
||||
from allianceauth.eveonline.models import EveCharacter
|
||||
from allianceauth.authentication.admin import user_profile_pic, \
|
||||
user_username, user_main_organization, MainCorporationsFilter,\
|
||||
from allianceauth.authentication.admin import (
|
||||
user_profile_pic,
|
||||
user_username,
|
||||
user_main_organization,
|
||||
MainCorporationsFilter,
|
||||
MainAllianceFilter
|
||||
)
|
||||
|
||||
from .models import NameFormatConfig
|
||||
|
||||
@@ -25,16 +25,24 @@ class ServicesUserAdmin(admin.ModelAdmin):
|
||||
list_select_related = True
|
||||
list_display = (
|
||||
user_profile_pic,
|
||||
user_username,
|
||||
user_username,
|
||||
'_state',
|
||||
user_main_organization,
|
||||
'_date_joined'
|
||||
)
|
||||
list_filter = (
|
||||
'user__profile__state',
|
||||
MainCorporationsFilter,
|
||||
MainAllianceFilter,
|
||||
'user__date_joined'
|
||||
'user__date_joined',
|
||||
)
|
||||
|
||||
def _state(self, obj):
|
||||
return obj.user.profile.state.name
|
||||
|
||||
_state.short_description = 'state'
|
||||
_state.admin_order_field = 'user__profile__state__name'
|
||||
|
||||
def _date_joined(self, obj):
|
||||
return obj.user.date_joined
|
||||
|
||||
@@ -45,7 +53,8 @@ class ServicesUserAdmin(admin.ModelAdmin):
|
||||
class NameFormatConfigForm(forms.ModelForm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(NameFormatConfigForm, self).__init__(*args, **kwargs)
|
||||
SERVICE_CHOICES = [(s.name, s.name) for h in hooks.get_hooks('services_hook') for s in [h()]]
|
||||
SERVICE_CHOICES = \
|
||||
[(s.name, s.name) for h in hooks.get_hooks('services_hook') for s in [h()]]
|
||||
if self.instance.id:
|
||||
current_choice = (self.instance.service_name, self.instance.service_name)
|
||||
if current_choice not in SERVICE_CHOICES:
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
default_app_config = 'allianceauth.services.modules.discord.apps.DiscordServiceConfig'
|
||||
default_app_config = 'allianceauth.services.modules.discord.apps.DiscordServiceConfig' # noqa
|
||||
|
||||
__title__ = 'Discord Service'
|
||||
|
||||
@@ -1,13 +1,22 @@
|
||||
import logging
|
||||
|
||||
from django.contrib import admin
|
||||
|
||||
from .models import DiscordUser
|
||||
from . import __title__
|
||||
from ...admin import ServicesUserAdmin
|
||||
from .models import DiscordUser
|
||||
from .utils import LoggerAddTag
|
||||
|
||||
|
||||
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
|
||||
|
||||
|
||||
@admin.register(DiscordUser)
|
||||
class DiscordUserAdmin(ServicesUserAdmin):
|
||||
list_display = ServicesUserAdmin.list_display + ('_uid',)
|
||||
search_fields = ServicesUserAdmin.search_fields + ('uid', )
|
||||
search_fields = ServicesUserAdmin.search_fields + ('uid', 'username')
|
||||
list_display = ServicesUserAdmin.list_display + ('activated', '_username', '_uid')
|
||||
list_filter = ServicesUserAdmin.list_filter + ('activated',)
|
||||
ordering = ('-activated',)
|
||||
|
||||
def _uid(self, obj):
|
||||
return obj.uid
|
||||
@@ -15,3 +24,11 @@ class DiscordUserAdmin(ServicesUserAdmin):
|
||||
_uid.short_description = 'Discord ID (UID)'
|
||||
_uid.admin_order_field = 'uid'
|
||||
|
||||
def _username(self, obj):
|
||||
if obj.username and obj.discriminator:
|
||||
return f'{obj.username}#{obj.discriminator}'
|
||||
else:
|
||||
return ''
|
||||
|
||||
_username.short_description = 'Discord Username'
|
||||
_username.admin_order_field = 'username'
|
||||
|
||||
17
allianceauth/services/modules/discord/app_settings.py
Normal file
17
allianceauth/services/modules/discord/app_settings.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .utils import clean_setting
|
||||
|
||||
|
||||
DISCORD_APP_ID = clean_setting('DISCORD_APP_ID', '')
|
||||
DISCORD_APP_SECRET = clean_setting('DISCORD_APP_SECRET', '')
|
||||
DISCORD_BOT_TOKEN = clean_setting('DISCORD_BOT_TOKEN', '')
|
||||
DISCORD_CALLBACK_URL = clean_setting('DISCORD_CALLBACK_URL', '')
|
||||
DISCORD_GUILD_ID = clean_setting('DISCORD_GUILD_ID', '')
|
||||
|
||||
# max retries of tasks after an error occurred
|
||||
DISCORD_TASKS_MAX_RETRIES = clean_setting('DISCORD_TASKS_MAX_RETRIES', 3)
|
||||
|
||||
# Pause in seconds until next retry for tasks after the API returned an error
|
||||
DISCORD_TASKS_RETRY_PAUSE = clean_setting('DISCORD_TASKS_RETRY_PAUSE', 60)
|
||||
|
||||
# automatically sync Discord users names to user's main character name when created
|
||||
DISCORD_SYNC_NAMES = clean_setting('DISCORD_SYNC_NAMES', False)
|
||||
@@ -1,17 +1,26 @@
|
||||
import logging
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from django.template.loader import render_to_string
|
||||
from django.conf import settings
|
||||
|
||||
from allianceauth import hooks
|
||||
from allianceauth.services.hooks import ServicesHook
|
||||
from .tasks import DiscordTasks
|
||||
from .urls import urlpatterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from .models import DiscordUser
|
||||
from .urls import urlpatterns
|
||||
from .utils import LoggerAddTag
|
||||
from . import tasks, __title__
|
||||
|
||||
|
||||
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
|
||||
|
||||
# Default priority for single tasks like update group and sync nickname
|
||||
SINGLE_TASK_PRIORITY = 3
|
||||
|
||||
|
||||
class DiscordService(ServicesHook):
|
||||
"""Service for managing a Discord server with Auth"""
|
||||
|
||||
def __init__(self):
|
||||
ServicesHook.__init__(self)
|
||||
self.urlpatterns = urlpatterns
|
||||
@@ -20,36 +29,85 @@ class DiscordService(ServicesHook):
|
||||
self.access_perm = 'discord.access_discord'
|
||||
self.name_format = '{character_name}'
|
||||
|
||||
def delete_user(self, user, notify_user=False):
|
||||
logger.debug('Deleting user %s %s account' % (user, self.name))
|
||||
return DiscordTasks.delete_user(user, notify_user=notify_user)
|
||||
def delete_user(self, user: User, notify_user: bool = False) -> None:
|
||||
if self.user_has_account(user):
|
||||
logger.debug('Deleting user %s %s account', user, self.name)
|
||||
tasks.delete_user.apply_async(
|
||||
kwargs={'user_pk': user.pk}, priority=SINGLE_TASK_PRIORITY
|
||||
)
|
||||
|
||||
def render_services_ctrl(self, request):
|
||||
if self.user_has_account(request.user):
|
||||
user_has_account = True
|
||||
username = request.user.discord.username
|
||||
discriminator = request.user.discord.discriminator
|
||||
if username and discriminator:
|
||||
discord_username = f'{username}#{discriminator}'
|
||||
else:
|
||||
discord_username = ''
|
||||
else:
|
||||
discord_username = ''
|
||||
user_has_account = False
|
||||
|
||||
def update_groups(self, user):
|
||||
logger.debug('Processing %s groups for %s' % (self.name, user))
|
||||
if DiscordTasks.has_account(user):
|
||||
DiscordTasks.update_groups.delay(user.pk)
|
||||
|
||||
def validate_user(self, user):
|
||||
logger.debug('Validating user %s %s account' % (user, self.name))
|
||||
if DiscordTasks.has_account(user) and not self.service_active_for_user(user):
|
||||
self.delete_user(user, notify_user=True)
|
||||
|
||||
def sync_nickname(self, user):
|
||||
logger.debug('Syncing %s nickname for user %s' % (self.name, user))
|
||||
DiscordTasks.update_nickname.apply_async(args=[user.pk], countdown=5)
|
||||
|
||||
def update_all_groups(self):
|
||||
logger.debug('Update all %s groups called' % self.name)
|
||||
DiscordTasks.update_all_groups.delay()
|
||||
return render_to_string(
|
||||
self.service_ctrl_template,
|
||||
{
|
||||
'server_name': DiscordUser.objects.server_name(),
|
||||
'user_has_account': user_has_account,
|
||||
'discord_username': discord_username
|
||||
},
|
||||
request=request
|
||||
)
|
||||
|
||||
def service_active_for_user(self, user):
|
||||
return user.has_perm(self.access_perm)
|
||||
|
||||
def render_services_ctrl(self, request):
|
||||
return render_to_string(self.service_ctrl_template, {
|
||||
'discord_uid': request.user.discord.uid if DiscordTasks.has_account(request.user) else None,
|
||||
'DISCORD_SERVER_ID': getattr(settings, 'DISCORD_GUILD_ID', ''),
|
||||
}, request=request)
|
||||
def sync_nickname(self, user):
|
||||
logger.debug('Syncing %s nickname for user %s', self.name, user)
|
||||
if self.user_has_account(user):
|
||||
tasks.update_nickname.apply_async(
|
||||
kwargs={'user_pk': user.pk}, priority=SINGLE_TASK_PRIORITY
|
||||
)
|
||||
|
||||
def sync_nicknames_bulk(self, users: list):
|
||||
"""Sync nickname for a list of users in bulk.
|
||||
Preferred over sync_nickname(), because it will not break the rate limit
|
||||
"""
|
||||
logger.debug(
|
||||
'Syncing %s nicknames in bulk for %d users', self.name, len(users)
|
||||
)
|
||||
user_pks = [user.pk for user in users]
|
||||
tasks.update_nicknames_bulk.delay(user_pks)
|
||||
|
||||
def update_all_groups(self):
|
||||
logger.debug('Update all %s groups called', self.name)
|
||||
tasks.update_all_groups.delay()
|
||||
|
||||
def update_groups(self, user):
|
||||
logger.debug('Processing %s groups for %s', self.name, user)
|
||||
if self.user_has_account(user):
|
||||
tasks.update_groups.apply_async(
|
||||
kwargs={'user_pk': user.pk}, priority=SINGLE_TASK_PRIORITY
|
||||
)
|
||||
|
||||
def update_groups_bulk(self, users: list):
|
||||
"""Updates groups for a list of users in bulk.
|
||||
Preferred over update_groups(), because it will not break the rate limit
|
||||
"""
|
||||
logger.debug(
|
||||
'Processing %s groups in bulk for %d users', self.name, len(users)
|
||||
)
|
||||
user_pks = [user.pk for user in users]
|
||||
tasks.update_groups_bulk.delay(user_pks)
|
||||
|
||||
@staticmethod
|
||||
def user_has_account(user: User) -> bool:
|
||||
return DiscordUser.objects.user_has_account(user)
|
||||
|
||||
def validate_user(self, user):
|
||||
logger.debug('Validating user %s %s account', user, self.name)
|
||||
if self.user_has_account(user) and not self.service_active_for_user(user):
|
||||
self.delete_user(user, notify_user=True)
|
||||
|
||||
|
||||
@hooks.register('services_hook')
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
from .client import DiscordClient # noqa
|
||||
from .exceptions import DiscordApiBackoff # noqa
|
||||
@@ -0,0 +1,40 @@
|
||||
from ..utils import clean_setting
|
||||
|
||||
|
||||
# Base URL for all API calls. Must end with /.
|
||||
DISCORD_API_BASE_URL = clean_setting(
|
||||
'DISCORD_API_BASE_URL', 'https://discordapp.com/api/'
|
||||
)
|
||||
|
||||
# Low level timeout for requests to the Discord API in ms
|
||||
DISCORD_API_TIMEOUT = clean_setting(
|
||||
'DISCORD_API_TIMEOUT', 5000
|
||||
)
|
||||
|
||||
# Base authorization URL for Discord Oauth
|
||||
DISCORD_OAUTH_BASE_URL = clean_setting(
|
||||
'DISCORD_OAUTH_BASE_URL', 'https://discordapp.com/api/oauth2/authorize'
|
||||
)
|
||||
|
||||
# Base authorization URL for Discord Oauth
|
||||
DISCORD_OAUTH_TOKEN_URL = clean_setting(
|
||||
'DISCORD_OAUTH_TOKEN_URL', 'https://discordapp.com/api/oauth2/token'
|
||||
)
|
||||
|
||||
# How long the Discord guild names retrieved from the server are
|
||||
# caches locally in milliseconds.
|
||||
DISCORD_GUILD_NAME_CACHE_MAX_AGE = clean_setting(
|
||||
'DISCORD_GUILD_NAME_CACHE_MAX_AGE', 3600 * 2 * 1000
|
||||
)
|
||||
|
||||
# How long Discord roles retrieved from the server are caches locally in milliseconds.
|
||||
DISCORD_ROLES_CACHE_MAX_AGE = clean_setting(
|
||||
'DISCORD_ROLES_CACHE_MAX_AGE', 3600 * 2 * 1000
|
||||
)
|
||||
|
||||
# Turns off creation of new roles. In case the rate limit for creating roles is
|
||||
# exhausted, this setting allows the Discord service to continue to function
|
||||
# and wait out the reset. Rate limit is about 250 per 48 hrs.
|
||||
DISCORD_DISABLE_ROLE_CREATION = clean_setting(
|
||||
'DISCORD_DISABLE_ROLE_CREATION', False
|
||||
)
|
||||
690
allianceauth/services/modules/discord/discord_client/client.py
Normal file
690
allianceauth/services/modules/discord/discord_client/client.py
Normal file
@@ -0,0 +1,690 @@
|
||||
from hashlib import md5
|
||||
import logging
|
||||
from time import sleep
|
||||
from urllib.parse import urljoin
|
||||
from uuid import uuid1
|
||||
|
||||
from redis import Redis
|
||||
import requests
|
||||
|
||||
from django.core.cache import caches
|
||||
|
||||
from allianceauth import __title__ as AUTH_TITLE, __url__, __version__
|
||||
|
||||
from .. import __title__
|
||||
from .app_settings import (
|
||||
DISCORD_API_BASE_URL,
|
||||
DISCORD_API_TIMEOUT,
|
||||
DISCORD_DISABLE_ROLE_CREATION,
|
||||
DISCORD_GUILD_NAME_CACHE_MAX_AGE,
|
||||
DISCORD_OAUTH_BASE_URL,
|
||||
DISCORD_OAUTH_TOKEN_URL,
|
||||
DISCORD_ROLES_CACHE_MAX_AGE,
|
||||
)
|
||||
from .exceptions import DiscordRateLimitExhausted, DiscordTooManyRequestsError
|
||||
from ..utils import LoggerAddTag
|
||||
|
||||
|
||||
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
|
||||
|
||||
# max requests that can be executed until reset
|
||||
RATE_LIMIT_MAX_REQUESTS = 5
|
||||
|
||||
# Time until remaining requests are reset
|
||||
RATE_LIMIT_RESETS_AFTER = 5000
|
||||
|
||||
# Delay used for API backoff in case no info returned from API on 429s
|
||||
DEFAULT_BACKOFF_DELAY = 5000
|
||||
|
||||
# additional duration to compensate for potential clock discrepancies
|
||||
# with the Discord server
|
||||
DURATION_CONTINGENCY = 500
|
||||
|
||||
# Client will do a blocking wait rather than throwing a backoff exception if the
|
||||
# time until next reset is below this threshold
|
||||
WAIT_THRESHOLD = 250
|
||||
|
||||
# If the rate limit resets soon we will wait it out and then retry to
|
||||
# either get a remaining request from our cached counter
|
||||
# or again wait out a short reset time and retry again.
|
||||
# This could happen several times within a high concurrency situation,
|
||||
# but must fail after x tries to avoid an infinite loop
|
||||
RATE_LIMIT_RETRIES = 1000
|
||||
|
||||
|
||||
class DiscordClient:
|
||||
"""This class provides a web client for interacting with the Discord API
|
||||
|
||||
The client has rate limiting that supports concurrency.
|
||||
This means it is able to ensure the API rate limit is not violated,
|
||||
even when used concurrently, e.g. with multiple parallel celery tasks.
|
||||
|
||||
In addition the client support proper API backoff.
|
||||
|
||||
Synchronization of rate limit infos accross multiple processes
|
||||
is implemented with Redis and thus requires Redis as Django cache backend.
|
||||
|
||||
All durations are in milliseconds.
|
||||
"""
|
||||
OAUTH_BASE_URL = DISCORD_OAUTH_BASE_URL
|
||||
OAUTH_TOKEN_URL = DISCORD_OAUTH_TOKEN_URL
|
||||
|
||||
_KEY_GLOBAL_BACKOFF_UNTIL = 'DISCORD_GLOBAL_BACKOFF_UNTIL'
|
||||
_KEY_GLOBAL_RATE_LIMIT_REMAINING = 'DISCORD_GLOBAL_RATE_LIMIT_REMAINING'
|
||||
_KEYPREFIX_GUILD_NAME = 'DISCORD_GUILD_NAME'
|
||||
_KEYPREFIX_ROLE_NAME = 'DISCORD_ROLE_NAME'
|
||||
_ROLE_NAME_MAX_CHARS = 100
|
||||
_NICK_MAX_CHARS = 32
|
||||
|
||||
_HTTP_STATUS_CODE_NOT_FOUND = 404
|
||||
_HTTP_STATUS_CODE_RATE_LIMITED = 429
|
||||
_DISCORD_STATUS_CODE_UNKNOWN_MEMBER = 10007
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
access_token: str,
|
||||
redis: Redis = None,
|
||||
is_rate_limited: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Params:
|
||||
- access_token: Discord access token used to authenticate all calls to the API
|
||||
- redis: Redis instance to be used.
|
||||
- is_rate_limited: Set to False to run of rate limiting (use with care)
|
||||
If not specified will try to use the Redis instance
|
||||
from the default Django cache backend.
|
||||
"""
|
||||
self._access_token = str(access_token)
|
||||
self._is_rate_limited = bool(is_rate_limited)
|
||||
if not redis:
|
||||
default_cache = caches['default']
|
||||
self._redis = default_cache.get_master_client()
|
||||
if not isinstance(self._redis, Redis):
|
||||
raise RuntimeError(
|
||||
'This class requires a Redis client, but none was provided '
|
||||
'and the default Django cache backend is not Redis either.'
|
||||
)
|
||||
else:
|
||||
self._redis = redis
|
||||
|
||||
lua_1 = """
|
||||
if redis.call("exists", KEYS[1]) == 0 then
|
||||
redis.call("set", KEYS[1], ARGV[1], 'px', ARGV[2])
|
||||
end
|
||||
return redis.call("decr", KEYS[1])
|
||||
"""
|
||||
self.__redis_script_decr_or_set = self._redis.register_script(lua_1)
|
||||
|
||||
lua_2 = """
|
||||
local current_px = tonumber(redis.call("pttl", KEYS[1]))
|
||||
if current_px < tonumber(ARGV[2]) then
|
||||
return redis.call("set", KEYS[1], ARGV[1], 'px', ARGV[2])
|
||||
else
|
||||
return nil
|
||||
end
|
||||
"""
|
||||
self.__redis_script_set_longer = self._redis.register_script(lua_2)
|
||||
|
||||
@property
|
||||
def access_token(self):
|
||||
return self._access_token
|
||||
|
||||
@property
|
||||
def is_rate_limited(self):
|
||||
return self._is_rate_limited
|
||||
|
||||
def __repr__(self):
|
||||
return f'{type(self).__name__}(access_token=...{self.access_token[-5:]})'
|
||||
|
||||
def _redis_decr_or_set(self, name: str, value: str, px: int) -> bool:
|
||||
"""decreases the key value if it exists and returns the result
|
||||
else sets the key
|
||||
|
||||
Implemented as Lua script to ensure atomicity.
|
||||
"""
|
||||
return self.__redis_script_decr_or_set(
|
||||
keys=[str(name)], args=[str(value), int(px)]
|
||||
)
|
||||
|
||||
def _redis_set_if_longer(self, name: str, value: str, px: int) -> bool:
|
||||
"""like set, but only goes through if either key doesn't exist
|
||||
or px would be extended.
|
||||
|
||||
Implemented as Lua script to ensure atomicity.
|
||||
"""
|
||||
return self.__redis_script_set_longer(
|
||||
keys=[str(name)], args=[str(value), int(px)]
|
||||
)
|
||||
|
||||
# users
|
||||
|
||||
def current_user(self) -> dict:
|
||||
"""returns the user belonging to the current access_token"""
|
||||
authorization = f'Bearer {self.access_token}'
|
||||
r = self._api_request(
|
||||
method='get', route='users/@me', authorization=authorization
|
||||
)
|
||||
return r.json()
|
||||
|
||||
# guild roles
|
||||
|
||||
def create_guild_role(self, guild_id: int, role_name: str, **kwargs) -> dict:
|
||||
"""Create a new guild role with the given name.
|
||||
See official documentation for additional optional parameters.
|
||||
|
||||
Note that Discord allows creating multiple roles with the name name,
|
||||
so it's important to check existing roles before creating new one
|
||||
to avoid duplicates.
|
||||
|
||||
return a new role object on success
|
||||
"""
|
||||
route = f"guilds/{guild_id}/roles"
|
||||
data = {'name': self._sanitize_role_name(role_name)}
|
||||
data.update(kwargs)
|
||||
r = self._api_request(method='post', route=route, data=data)
|
||||
return r.json()
|
||||
|
||||
def guild_infos(self, guild_id: int) -> dict:
|
||||
"""Returns all basic infos about this guild"""
|
||||
route = f"guilds/{guild_id}"
|
||||
r = self._api_request(method='get', route=route)
|
||||
return r.json()
|
||||
|
||||
def guild_name(self, guild_id: int) -> str:
|
||||
"""returns the name of this guild (cached)
|
||||
or an empty string if something went wrong
|
||||
"""
|
||||
key_name = self._guild_name_cache_key(guild_id)
|
||||
guild_name = self._redis_decode(self._redis.get(key_name))
|
||||
if not guild_name:
|
||||
guild_infos = self.guild_infos(guild_id)
|
||||
if 'name' in guild_infos:
|
||||
guild_name = guild_infos['name']
|
||||
self._redis.set(
|
||||
name=key_name,
|
||||
value=guild_name,
|
||||
px=DISCORD_GUILD_NAME_CACHE_MAX_AGE
|
||||
)
|
||||
else:
|
||||
guild_name = ''
|
||||
|
||||
return guild_name
|
||||
|
||||
@classmethod
|
||||
def _guild_name_cache_key(cls, guild_id: int) -> str:
|
||||
"""Returns key for accessing role given by name in the role cache"""
|
||||
gen_key = DiscordClient._generate_hash(f'{guild_id}')
|
||||
return f'{cls._KEYPREFIX_GUILD_NAME}__{gen_key}'
|
||||
|
||||
def guild_roles(self, guild_id: int) -> list:
|
||||
"""Returns the list of all roles for this guild"""
|
||||
route = f"guilds/{guild_id}/roles"
|
||||
r = self._api_request(method='get', route=route)
|
||||
return r.json()
|
||||
|
||||
def delete_guild_role(self, guild_id: int, role_id: int) -> bool:
|
||||
"""Deletes a guild role"""
|
||||
route = f"guilds/{guild_id}/roles/{role_id}"
|
||||
r = self._api_request(method='delete', route=route)
|
||||
if r.status_code == 204:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
# guild role cache
|
||||
|
||||
def match_guild_roles_to_names(self, guild_id: int, role_names: list) -> list:
|
||||
"""returns Discord roles matching the given names
|
||||
|
||||
Returns as list of tuple of role and created flag
|
||||
|
||||
Will try to match with existing roles names
|
||||
Non-existing roles will be created, then created flag will be True
|
||||
Roles names are cached to improve performance
|
||||
"""
|
||||
roles = list()
|
||||
for role_name in role_names:
|
||||
role, created = self.match_guild_role_to_name(
|
||||
guild_id=guild_id, role_name=self._sanitize_role_name(role_name)
|
||||
)
|
||||
if role:
|
||||
roles.append((role, created))
|
||||
return roles
|
||||
|
||||
def match_guild_role_to_name(self, guild_id: int, role_name: str) -> tuple:
|
||||
"""returns Discord role matching the given name
|
||||
|
||||
Returns as tuple of role and created flag
|
||||
|
||||
Will try to match with existing roles names
|
||||
Non-existing roles will be created, then created flag will be True
|
||||
Roles names are cached to improve performance
|
||||
"""
|
||||
created = False
|
||||
role_name = self._sanitize_role_name(role_name)
|
||||
role_id = self._redis_decode(
|
||||
self._redis.get(name=self._role_cache_key(guild_id, role_name))
|
||||
)
|
||||
if not role_id:
|
||||
role_id = None
|
||||
for role in self.guild_roles(guild_id):
|
||||
self._update_role_cache(guild_id, role)
|
||||
if role['name'] == role_name:
|
||||
role_id = role['id']
|
||||
|
||||
if role_id:
|
||||
role = self._create_role(role_id, role_name)
|
||||
|
||||
else:
|
||||
if not DISCORD_DISABLE_ROLE_CREATION:
|
||||
role_raw = self.create_guild_role(guild_id, role_name)
|
||||
role = self._create_role(role_raw['id'], role_name)
|
||||
self._update_role_cache(guild_id, role)
|
||||
created = True
|
||||
else:
|
||||
role = None
|
||||
else:
|
||||
role = self._create_role(int(role_id), role_name)
|
||||
|
||||
return role, created
|
||||
|
||||
@staticmethod
|
||||
def _create_role(role_id: int, role_name: str) -> dict:
|
||||
return {'id': int(role_id), 'name': str(role_name)}
|
||||
|
||||
def _update_role_cache(self, guild_id: int, role: dict) -> bool:
|
||||
"""updates role cache with given role
|
||||
|
||||
Returns True on success, else False or raises exception
|
||||
"""
|
||||
if not isinstance(role, dict):
|
||||
raise TypeError('role must be a dict')
|
||||
|
||||
return self._redis.set(
|
||||
name=self._role_cache_key(guild_id=guild_id, role_name=role['name']),
|
||||
value=role['id'],
|
||||
px=DISCORD_ROLES_CACHE_MAX_AGE
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _role_cache_key(cls, guild_id: int, role_name: str) -> str:
|
||||
"""Returns key for accessing role given by name in the role cache"""
|
||||
gen_key = DiscordClient._generate_hash(f'{guild_id}{role_name}')
|
||||
return f'{cls._KEYPREFIX_ROLE_NAME}__{gen_key}'
|
||||
|
||||
# guild members
|
||||
|
||||
def add_guild_member(
|
||||
self,
|
||||
guild_id: int,
|
||||
user_id: int,
|
||||
access_token: str,
|
||||
role_ids: list = None,
|
||||
nick: str = None
|
||||
) -> bool:
|
||||
"""Adds a user to the guilds.
|
||||
|
||||
Returns:
|
||||
- True when a new user was added
|
||||
- None if the user already existed
|
||||
- False when something went wrong or raises exception
|
||||
"""
|
||||
route = f"guilds/{guild_id}/members/{user_id}"
|
||||
data = {
|
||||
'access_token': str(access_token)
|
||||
}
|
||||
if role_ids:
|
||||
data['roles'] = self._sanitize_role_ids(role_ids)
|
||||
|
||||
if nick:
|
||||
data['nick'] = str(nick)[:self._NICK_MAX_CHARS]
|
||||
|
||||
r = self._api_request(method='put', route=route, data=data)
|
||||
r.raise_for_status()
|
||||
if r.status_code == 201:
|
||||
return True
|
||||
elif r.status_code == 204:
|
||||
return None
|
||||
else:
|
||||
return False
|
||||
|
||||
def guild_member(self, guild_id: int, user_id: int) -> dict:
|
||||
"""returns the user info for a guild member
|
||||
|
||||
or None if the user is not a member of the guild
|
||||
"""
|
||||
route = f'guilds/{guild_id}/members/{user_id}'
|
||||
r = self._api_request(method='get', route=route, raise_for_status=False)
|
||||
if self._is_member_unknown_error(r):
|
||||
logger.warning("Discord user ID %s could not be found on server.", user_id)
|
||||
return None
|
||||
else:
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
def modify_guild_member(
|
||||
self, guild_id: int, user_id: int, role_ids: list = None, nick: str = None
|
||||
) -> bool:
|
||||
"""Modify attributes of a guild member.
|
||||
|
||||
Returns
|
||||
- True when successful
|
||||
- None if user is not a member of this guild
|
||||
- False otherwise
|
||||
"""
|
||||
if not role_ids and not nick:
|
||||
raise ValueError('Must specify role_ids or nick')
|
||||
|
||||
if role_ids and not isinstance(role_ids, list):
|
||||
raise TypeError('role_ids must be a list type')
|
||||
|
||||
data = dict()
|
||||
if role_ids:
|
||||
data['roles'] = self._sanitize_role_ids(role_ids)
|
||||
|
||||
if nick:
|
||||
data['nick'] = self._sanitize_nick(nick)
|
||||
|
||||
route = f"guilds/{guild_id}/members/{user_id}"
|
||||
r = self._api_request(
|
||||
method='patch', route=route, data=data, raise_for_status=False
|
||||
)
|
||||
if self._is_member_unknown_error(r):
|
||||
logger.warning('User ID %s is not a member of this guild', user_id)
|
||||
return None
|
||||
else:
|
||||
r.raise_for_status()
|
||||
|
||||
if r.status_code == 204:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def remove_guild_member(self, guild_id: int, user_id: int) -> bool:
|
||||
"""Remove a member from a guild
|
||||
|
||||
Returns:
|
||||
- True when successful
|
||||
- None if member does not exist
|
||||
- False otherwise
|
||||
"""
|
||||
route = f"guilds/{guild_id}/members/{user_id}"
|
||||
r = self._api_request(
|
||||
method='delete', route=route, raise_for_status=False
|
||||
)
|
||||
if self._is_member_unknown_error(r):
|
||||
logger.warning('User ID %s is not a member of this guild', user_id)
|
||||
return None
|
||||
else:
|
||||
r.raise_for_status()
|
||||
|
||||
if r.status_code == 204:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
# Guild member roles
|
||||
|
||||
def add_guild_member_role(
|
||||
self, guild_id: int, user_id: int, role_id: int
|
||||
) -> bool:
|
||||
"""Adds a role to a guild member
|
||||
|
||||
Returns:
|
||||
- True when successful
|
||||
- None if member does not exist
|
||||
- False otherwise
|
||||
"""
|
||||
route = f"guilds/{guild_id}/members/{user_id}/roles/{role_id}"
|
||||
r = self._api_request(method='put', route=route, raise_for_status=False)
|
||||
if self._is_member_unknown_error(r):
|
||||
logger.warning('User ID %s is not a member of this guild', user_id)
|
||||
return None
|
||||
else:
|
||||
r.raise_for_status()
|
||||
|
||||
if r.status_code == 204:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def remove_guild_member_role(
|
||||
self, guild_id: int, user_id: int, role_id: int
|
||||
) -> bool:
|
||||
"""Removes a role to a guild member
|
||||
|
||||
Returns:
|
||||
- True when successful
|
||||
- None if member does not exist
|
||||
- False otherwise
|
||||
"""
|
||||
route = f"guilds/{guild_id}/members/{user_id}/roles/{role_id}"
|
||||
r = self._api_request(method='delete', route=route, raise_for_status=False)
|
||||
if self._is_member_unknown_error(r):
|
||||
logger.warning('User ID %s is not a member of this guild', user_id)
|
||||
return None
|
||||
else:
|
||||
r.raise_for_status()
|
||||
|
||||
if r.status_code == 204:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _is_member_unknown_error(cls, r: requests.Response) -> bool:
|
||||
try:
|
||||
result = (
|
||||
r.status_code == cls._HTTP_STATUS_CODE_NOT_FOUND
|
||||
and r.json()['code'] == cls._DISCORD_STATUS_CODE_UNKNOWN_MEMBER
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
result = False
|
||||
|
||||
return result
|
||||
|
||||
# Internal methods
|
||||
|
||||
def _api_request(
|
||||
self,
|
||||
method: str,
|
||||
route: str,
|
||||
data: dict = None,
|
||||
authorization: str = None,
|
||||
raise_for_status: bool = True
|
||||
) -> requests.Response:
|
||||
"""Core method for performing all API calls"""
|
||||
uid = uuid1().hex
|
||||
|
||||
if not hasattr(requests, method):
|
||||
raise ValueError('Invalid method: %s' % method)
|
||||
|
||||
if not authorization:
|
||||
authorization = f'Bot {self.access_token}'
|
||||
|
||||
self._handle_ongoing_api_backoff(uid)
|
||||
if self.is_rate_limited:
|
||||
self._ensure_rate_limed_not_exhausted(uid)
|
||||
headers = {
|
||||
'User-Agent': f'{AUTH_TITLE} ({__url__}, {__version__})',
|
||||
'accept': 'application/json',
|
||||
'X-RateLimit-Precision': 'millisecond',
|
||||
'authorization': str(authorization)
|
||||
}
|
||||
if data:
|
||||
headers['content-type'] = 'application/json'
|
||||
|
||||
url = urljoin(DISCORD_API_BASE_URL, route)
|
||||
args = {
|
||||
'url': url,
|
||||
'headers': headers,
|
||||
'timeout': DISCORD_API_TIMEOUT / 1000
|
||||
}
|
||||
if data:
|
||||
args['json'] = data
|
||||
|
||||
logger.info('%s: sending %s request to url \'%s\'', uid, method.upper(), url)
|
||||
logger.debug('%s: request headers:\n%s', uid, headers)
|
||||
r = getattr(requests, method)(**args)
|
||||
logger.debug(
|
||||
'%s: returned status code %d with headers:\n%s',
|
||||
uid,
|
||||
r.status_code,
|
||||
r.headers
|
||||
)
|
||||
logger.debug('%s: response:\n%s', uid, r.text)
|
||||
if not r.ok:
|
||||
logger.warning(
|
||||
'%s: Discord API returned error code %d and this response: %s',
|
||||
uid,
|
||||
r.status_code,
|
||||
r.text
|
||||
)
|
||||
|
||||
if r.status_code == self._HTTP_STATUS_CODE_RATE_LIMITED:
|
||||
self._handle_new_api_backoff(r, uid)
|
||||
|
||||
self._report_rate_limit_from_api(r, uid)
|
||||
|
||||
if raise_for_status:
|
||||
r.raise_for_status()
|
||||
|
||||
return r
|
||||
|
||||
def _handle_ongoing_api_backoff(self, uid: str) -> None:
|
||||
"""checks if api is currently on backoff
|
||||
if on backoff: will do a blocking wait if it expires soon,
|
||||
else raises exception
|
||||
"""
|
||||
global_backoff_duration = self._redis.pttl(self._KEY_GLOBAL_BACKOFF_UNTIL)
|
||||
if global_backoff_duration > 0:
|
||||
if global_backoff_duration < WAIT_THRESHOLD:
|
||||
logger.info(
|
||||
'%s: Global API backoff still ongoing for %s ms. Waiting.',
|
||||
uid,
|
||||
global_backoff_duration
|
||||
)
|
||||
sleep(global_backoff_duration / 1000)
|
||||
else:
|
||||
logger.info(
|
||||
'%s: Global API backoff still ongoing for %s ms. Re-raising.',
|
||||
uid,
|
||||
global_backoff_duration
|
||||
)
|
||||
raise DiscordTooManyRequestsError(retry_after=global_backoff_duration)
|
||||
|
||||
def _ensure_rate_limed_not_exhausted(self, uid: str) -> int:
|
||||
"""ensures that the rate limit is not exhausted
|
||||
if exhausted: will do a blocking wait if rate limit resets soon,
|
||||
else raises exception
|
||||
|
||||
returns requests remaining on success
|
||||
"""
|
||||
for _ in range(RATE_LIMIT_RETRIES):
|
||||
requests_remaining = self._redis_decr_or_set(
|
||||
name=self._KEY_GLOBAL_RATE_LIMIT_REMAINING,
|
||||
value=RATE_LIMIT_MAX_REQUESTS,
|
||||
px=RATE_LIMIT_RESETS_AFTER + DURATION_CONTINGENCY
|
||||
)
|
||||
resets_in = self._redis.pttl(self._KEY_GLOBAL_RATE_LIMIT_REMAINING)
|
||||
if requests_remaining >= 0:
|
||||
logger.debug(
|
||||
'%s: Got %d remaining requests until reset in %s ms',
|
||||
uid,
|
||||
requests_remaining + 1,
|
||||
resets_in
|
||||
)
|
||||
return requests_remaining
|
||||
|
||||
elif resets_in < WAIT_THRESHOLD:
|
||||
sleep(resets_in / 1000)
|
||||
logger.debug(
|
||||
'%s: No requests remaining until reset in %d ms. '
|
||||
'Waiting for reset.',
|
||||
uid,
|
||||
resets_in
|
||||
)
|
||||
continue
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
'%s: No requests remaining until reset in %d ms. '
|
||||
'Raising exception.',
|
||||
uid,
|
||||
resets_in
|
||||
)
|
||||
raise DiscordRateLimitExhausted(resets_in)
|
||||
|
||||
raise RuntimeError('Failed to handle rate limit after after too tries.')
|
||||
|
||||
def _handle_new_api_backoff(self, r: requests.Response, uid: str) -> None:
|
||||
"""raises exception for new API backoff error"""
|
||||
response = r.json()
|
||||
if 'retry_after' in response:
|
||||
try:
|
||||
retry_after = \
|
||||
int(response['retry_after']) + DURATION_CONTINGENCY
|
||||
except ValueError:
|
||||
retry_after = DEFAULT_BACKOFF_DELAY
|
||||
else:
|
||||
retry_after = DEFAULT_BACKOFF_DELAY
|
||||
self._redis_set_if_longer(
|
||||
name=self._KEY_GLOBAL_BACKOFF_UNTIL,
|
||||
value='GLOBAL_API_BACKOFF',
|
||||
px=retry_after
|
||||
)
|
||||
logger.warning(
|
||||
"%s: Rate limit violated. Need to back off for at least %d ms",
|
||||
uid,
|
||||
retry_after
|
||||
)
|
||||
raise DiscordTooManyRequestsError(retry_after=retry_after)
|
||||
|
||||
def _report_rate_limit_from_api(self, r, uid):
|
||||
"""Tries to log the current rate limit reported from API"""
|
||||
if (
|
||||
logger.getEffectiveLevel() <= logging.DEBUG
|
||||
and 'x-ratelimit-limit' in r.headers
|
||||
and 'x-ratelimit-remaining' in r.headers
|
||||
and 'x-ratelimit-reset-after' in r.headers
|
||||
):
|
||||
try:
|
||||
limit = int(r.headers['x-ratelimit-limit'])
|
||||
remaining = int(r.headers['x-ratelimit-remaining'])
|
||||
reset_after = float(r.headers['x-ratelimit-reset-after']) * 1000
|
||||
if remaining + 1 == limit:
|
||||
logger.debug(
|
||||
'%s: Rate limit reported from API: %d requests per %s ms',
|
||||
uid,
|
||||
limit,
|
||||
reset_after
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _redis_decode(value: str) -> str:
|
||||
"""Decodes a string from Redis and passes through None and Booleans"""
|
||||
if value is not None and not isinstance(value, bool):
|
||||
return value.decode('utf-8')
|
||||
else:
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _generate_hash(key: str) -> str:
|
||||
return md5(key.encode('utf-8')).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_role_ids(role_ids: list) -> list:
|
||||
"""make sure its a list of integers"""
|
||||
return [int(role_id) for role_id in list(role_ids)]
|
||||
|
||||
@classmethod
|
||||
def _sanitize_role_name(cls, role_name: str) -> str:
|
||||
"""shortens too long strings if necessary"""
|
||||
return str(role_name)[:cls._ROLE_NAME_MAX_CHARS]
|
||||
|
||||
@classmethod
|
||||
def _sanitize_nick(cls, nick: str) -> str:
|
||||
"""shortens too long strings if necessary"""
|
||||
return str(nick)[:cls._NICK_MAX_CHARS]
|
||||
@@ -0,0 +1,33 @@
|
||||
import math
|
||||
|
||||
|
||||
class DiscordClientException(Exception):
|
||||
"""Base Exception for the Discord client"""
|
||||
|
||||
|
||||
class DiscordApiBackoff(DiscordClientException):
|
||||
"""Exception signaling we need to backoff from sending requests to the API for now
|
||||
"""
|
||||
|
||||
def __init__(self, retry_after: int):
|
||||
"""
|
||||
:param retry_after: int time to retry after in milliseconds
|
||||
"""
|
||||
super().__init__()
|
||||
self.retry_after = int(retry_after)
|
||||
|
||||
@property
|
||||
def retry_after_seconds(self):
|
||||
return math.ceil(self.retry_after / 1000)
|
||||
|
||||
|
||||
class DiscordRateLimitExhausted(DiscordApiBackoff):
|
||||
"""Exception signaling that the total number of requests allowed under the
|
||||
current rate limit have been exhausted and weed to wait until next reset.
|
||||
"""
|
||||
|
||||
|
||||
class DiscordTooManyRequestsError(DiscordApiBackoff):
|
||||
"""API has responded with a 429 Too Many Requests Error.
|
||||
Need to backoff for now.
|
||||
"""
|
||||
@@ -0,0 +1,85 @@
|
||||
"""This is script is for concurrency testing the Discord client with a Discord server.
|
||||
|
||||
It will run multiple requests against Discord with multiple workers in parallel.
|
||||
The results can be analysed in a special log file.
|
||||
|
||||
This script is design to be run manually as unit test, e.g. by running the following:
|
||||
|
||||
python manage.py test
|
||||
allianceauth.services.modules.discord.discord_client.tests.piloting_concurrency
|
||||
|
||||
To make it work please set the below mentioned environment variables for your server.
|
||||
Since this may cause lots of 429s we'd recommend NOT to use your
|
||||
alliance Discord server for this.
|
||||
"""
|
||||
|
||||
import os
|
||||
from random import random
|
||||
import threading
|
||||
from time import sleep
|
||||
from django.test import TestCase
|
||||
|
||||
from .. import DiscordClient, DiscordApiBackoff
|
||||
|
||||
from ...utils import set_logger_to_file
|
||||
|
||||
logger = set_logger_to_file(
|
||||
'allianceauth.services.modules.discord.discord_client.client', __file__
|
||||
)
|
||||
|
||||
# Make sure to set these environnement variables for your Discord server and user
|
||||
DISCORD_GUILD_ID = os.environ['DISCORD_GUILD_ID']
|
||||
DISCORD_BOT_TOKEN = os.environ['DISCORD_BOT_TOKEN']
|
||||
DISCORD_USER_ID = os.environ['DISCORD_USER_ID']
|
||||
NICK = 'Dummy'
|
||||
|
||||
# Configure these settings to adjust the load profile
|
||||
NUMBER_OF_WORKERS = 5
|
||||
NUMBER_OF_RUNS = 10
|
||||
|
||||
# max seconds a worker waits before starting a new run
|
||||
# set to near 0 for max load preassure
|
||||
MAX_JITTER_PER_RUN_SECS = 1.0
|
||||
|
||||
|
||||
def worker(num: int):
|
||||
"""worker function"""
|
||||
worker_info = 'worker %d' % num
|
||||
logger.info('%s: started', worker_info)
|
||||
client = DiscordClient(DISCORD_BOT_TOKEN)
|
||||
try:
|
||||
runs = 0
|
||||
while runs < NUMBER_OF_RUNS:
|
||||
run_info = '%s: run %d' % (worker_info, runs + 1)
|
||||
my_jitter_secs = random() * MAX_JITTER_PER_RUN_SECS
|
||||
logger.info('%s - waiting %s secs', run_info, f'{my_jitter_secs:.3f}')
|
||||
sleep(my_jitter_secs)
|
||||
logger.info('%s - started', run_info)
|
||||
try:
|
||||
client.modify_guild_member(
|
||||
DISCORD_GUILD_ID, DISCORD_USER_ID, nick=NICK
|
||||
)
|
||||
runs += 1
|
||||
except DiscordApiBackoff as bo:
|
||||
message = '%s - waiting out API backoff for %d ms' % (
|
||||
run_info, bo.retry_after
|
||||
)
|
||||
logger.info(message)
|
||||
print()
|
||||
print(message)
|
||||
sleep(bo.retry_after / 1000)
|
||||
|
||||
except Exception as ex:
|
||||
logger.exception('%s: Processing aborted: %s', worker_info, ex)
|
||||
|
||||
logger.info('%s: finished', worker_info)
|
||||
return
|
||||
|
||||
|
||||
class TestMulti(TestCase):
|
||||
|
||||
def test_multi(self):
|
||||
logger.info('Starting multi test')
|
||||
for num in range(NUMBER_OF_WORKERS):
|
||||
x = threading.Thread(target=worker, args=(num + 1,))
|
||||
x.start()
|
||||
@@ -0,0 +1,130 @@
|
||||
"""This script is for functional testing of the Discord client with a Discord server
|
||||
|
||||
It will run single requests of the various functions to validate
|
||||
that they actually work - excluding those that require Oauth, or does not work
|
||||
with a bot token. The results can be also seen in a special log file.
|
||||
|
||||
This script is design to be run manually as unit test, e.g. by running the following:
|
||||
|
||||
python manage.py test
|
||||
allianceauth.services.modules.discord.discord_self.client.tests.piloting_functionality
|
||||
|
||||
To make it work please set the below mentioned environment variables for your server.
|
||||
Since this may cause lots of 429s we'd recommend NOT to use your
|
||||
alliance Discord server for this.
|
||||
"""
|
||||
|
||||
from uuid import uuid1
|
||||
import os
|
||||
from unittest import TestCase
|
||||
from time import sleep
|
||||
|
||||
from .. import DiscordClient
|
||||
from ...utils import set_logger_to_file
|
||||
|
||||
logger = set_logger_to_file(
|
||||
'allianceauth.services.modules.discord.discord_self.client.client', __file__
|
||||
)
|
||||
|
||||
# Make sure to set these environnement variables for your Discord server and user
|
||||
DISCORD_GUILD_ID = os.environ['DISCORD_GUILD_ID']
|
||||
DISCORD_BOT_TOKEN = os.environ['DISCORD_BOT_TOKEN']
|
||||
DISCORD_USER_ID = os.environ['DISCORD_USER_ID']
|
||||
|
||||
RATE_LIMIT_DELAY_SECS = 1
|
||||
|
||||
|
||||
class TestDiscordApiLive(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
logger.info('Live demo of the Discord API Client')
|
||||
cls.client = DiscordClient(DISCORD_BOT_TOKEN)
|
||||
|
||||
def test_run_other_features(self):
|
||||
"""runs features that have not been run in any of the other tests"""
|
||||
self.client.guild_infos(DISCORD_GUILD_ID)
|
||||
sleep(RATE_LIMIT_DELAY_SECS)
|
||||
|
||||
self.client.guild_name(DISCORD_GUILD_ID)
|
||||
sleep(RATE_LIMIT_DELAY_SECS)
|
||||
|
||||
self.client.match_guild_role_to_name(DISCORD_GUILD_ID, 'Testrole')
|
||||
sleep(RATE_LIMIT_DELAY_SECS)
|
||||
|
||||
self.client.match_guild_roles_to_names(
|
||||
DISCORD_GUILD_ID, ['Testrole A', 'Testrole B']
|
||||
)
|
||||
sleep(RATE_LIMIT_DELAY_SECS)
|
||||
|
||||
def test_create_and_remove_roles(self):
|
||||
# get base
|
||||
logger.info('guild_roles')
|
||||
expected = {role['id'] for role in self.client.guild_roles(DISCORD_GUILD_ID)}
|
||||
|
||||
# add role
|
||||
role_name = 'my test role 12345678'
|
||||
logger.info('create_guild_role')
|
||||
new_role = self.client.create_guild_role(
|
||||
guild_id=DISCORD_GUILD_ID, role_name=role_name
|
||||
)
|
||||
sleep(RATE_LIMIT_DELAY_SECS)
|
||||
self.assertEqual(new_role['name'], role_name)
|
||||
|
||||
# remove role again
|
||||
logger.info('delete_guild_role')
|
||||
self.client.delete_guild_role(
|
||||
guild_id=DISCORD_GUILD_ID, role_id=new_role['id']
|
||||
)
|
||||
sleep(RATE_LIMIT_DELAY_SECS)
|
||||
|
||||
# verify it worked
|
||||
logger.info('guild_roles')
|
||||
role_ids = {role['id'] for role in self.client.guild_roles(DISCORD_GUILD_ID)}
|
||||
sleep(RATE_LIMIT_DELAY_SECS)
|
||||
self.assertSetEqual(role_ids, expected)
|
||||
|
||||
def test_change_member_nick(self):
|
||||
# set new nick for user
|
||||
logger.info('modify_guild_member')
|
||||
new_nick = f'Testnick {uuid1().hex}'[:32]
|
||||
self.assertTrue(
|
||||
self.client.modify_guild_member(
|
||||
guild_id=DISCORD_GUILD_ID, user_id=DISCORD_USER_ID, nick=new_nick
|
||||
)
|
||||
)
|
||||
sleep(RATE_LIMIT_DELAY_SECS)
|
||||
|
||||
# verify it is saved
|
||||
logger.info('guild_member')
|
||||
user = self.client.guild_member(DISCORD_GUILD_ID, DISCORD_USER_ID)
|
||||
sleep(RATE_LIMIT_DELAY_SECS)
|
||||
self.assertEqual(user['nick'], new_nick)
|
||||
|
||||
def test_member_add_remove_roles(self):
|
||||
# create new guild role
|
||||
logger.info('create_guild_role')
|
||||
new_role = self.client.create_guild_role(
|
||||
guild_id=DISCORD_GUILD_ID, role_name='Special role 98765'
|
||||
)
|
||||
sleep(RATE_LIMIT_DELAY_SECS)
|
||||
new_role_id = new_role['id']
|
||||
|
||||
# add to member
|
||||
logger.info('add_guild_member_role')
|
||||
self.assertTrue(
|
||||
self.client.add_guild_member_role(
|
||||
guild_id=DISCORD_GUILD_ID, user_id=DISCORD_USER_ID, role_id=new_role_id
|
||||
)
|
||||
)
|
||||
sleep(RATE_LIMIT_DELAY_SECS)
|
||||
|
||||
# remove again
|
||||
logger.info('remove_guild_member_role')
|
||||
self.assertTrue(
|
||||
self.client.remove_guild_member_role(
|
||||
guild_id=DISCORD_GUILD_ID, user_id=DISCORD_USER_ID, role_id=new_role_id
|
||||
)
|
||||
)
|
||||
sleep(RATE_LIMIT_DELAY_SECS)
|
||||
47
allianceauth/services/modules/discord/discord_client/tests/piloting_tasks.py
Executable file
47
allianceauth/services/modules/discord/discord_client/tests/piloting_tasks.py
Executable file
@@ -0,0 +1,47 @@
|
||||
"""Load testing Discord services tasks
|
||||
|
||||
This script will load test the Discord service tasks.
|
||||
Note that his will run against your production Auth.
|
||||
To run this test start a bunch of celery workers and then run this script directly.
|
||||
|
||||
This script requires a user with a Discord account setup through Auth.
|
||||
Please provide the respective Discord user ID by setting it as environment variable:
|
||||
|
||||
export DISCORD_USER_ID="123456789"
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
myauth_dir = '/home/erik997/dev/python/aa/allianceauth-dev/myauth'
|
||||
sys.path.insert(0, myauth_dir)
|
||||
|
||||
import django # noqa: E402
|
||||
|
||||
# init and setup django project
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "myauth.settings.local")
|
||||
django.setup()
|
||||
|
||||
from uuid import uuid1 # noqa: E402
|
||||
|
||||
from django.contrib.auth.models import User # noqa: E402
|
||||
# from allianceauth.services.modules.discord.tasks import update_groups # noqa: E402
|
||||
|
||||
if 'DISCORD_USER_ID' not in os.environ:
|
||||
print('Please set DISCORD_USER_ID')
|
||||
exit()
|
||||
|
||||
DISCORD_USER_ID = os.environ['DISCORD_USER_ID']
|
||||
|
||||
|
||||
def run_many_updates(runs):
|
||||
user = User.objects.get(discord__uid=DISCORD_USER_ID)
|
||||
for _ in range(runs):
|
||||
new_nick = f'Testnick {uuid1().hex}'[:32]
|
||||
user.profile.main_character.character_name = new_nick
|
||||
user.profile.main_character.save()
|
||||
# update_groups.delay(user_pk=user.pk)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_many_updates(20)
|
||||
@@ -0,0 +1,26 @@
|
||||
# Discord rate limits
|
||||
|
||||
The following table shows the rate limit as reported from the API for different routes.
|
||||
|
||||
method | limit | reset | rate / s | bucket
|
||||
-- | -- | -- | -- | --
|
||||
add_guild_member | 10 | 10,000 | 1 | self
|
||||
create_guild_role | 250 | 180,000,000 | 0.001 | self
|
||||
delete_guild_role | g | g | g | g
|
||||
guild_member | 5 | 1,000 | 5 | self
|
||||
guild_roles | g | g | g | g
|
||||
add_guild_member_role | 10 | 10,000 | 1 | B1
|
||||
remove_guild_member_role | 10 | 10,000 | 1 | B1
|
||||
modify_guild_member | 10 | 10,000 | 1 | self
|
||||
remove_guild_member | 5 | 1,000 | 5 | self
|
||||
current_user | g | g | g | g
|
||||
|
||||
Legend:
|
||||
|
||||
- g: global rate limit. API does not provide any rate limit infos for those routes.
|
||||
|
||||
- reset: Values in milliseconds.
|
||||
|
||||
- bucket: "self" means the rate limit is only counted for that route, Bx means the same rate limit is counted for multiple routes.
|
||||
|
||||
- Data was collected on 2020-MAY-07 and is subject to change.
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,33 @@
|
||||
from unittest import TestCase
|
||||
|
||||
from ..exceptions import (
|
||||
DiscordApiBackoff,
|
||||
DiscordClientException,
|
||||
DiscordRateLimitExhausted,
|
||||
DiscordTooManyRequestsError
|
||||
)
|
||||
|
||||
|
||||
class TestExceptions(TestCase):
|
||||
|
||||
def test_DiscordApiException(self):
|
||||
with self.assertRaises(DiscordClientException):
|
||||
raise DiscordClientException()
|
||||
|
||||
def test_DiscordApiBackoff_raise(self):
|
||||
with self.assertRaises(DiscordApiBackoff):
|
||||
raise DiscordApiBackoff(999)
|
||||
|
||||
def test_DiscordApiBackoff_retry_after_seconds(self):
|
||||
retry_after = 999
|
||||
ex = DiscordApiBackoff(retry_after)
|
||||
self.assertEqual(ex.retry_after, retry_after)
|
||||
self.assertEqual(ex.retry_after_seconds, 1)
|
||||
|
||||
def test_DiscordRateLimitedExhausted_raise(self):
|
||||
with self.assertRaises(DiscordRateLimitExhausted):
|
||||
raise DiscordRateLimitExhausted(999)
|
||||
|
||||
def test_DiscordApiBackoffError_raise(self):
|
||||
with self.assertRaises(DiscordTooManyRequestsError):
|
||||
raise DiscordTooManyRequestsError(999)
|
||||
@@ -1,333 +0,0 @@
|
||||
import requests
|
||||
import math
|
||||
from django.conf import settings
|
||||
from requests_oauthlib import OAuth2Session
|
||||
from functools import wraps
|
||||
import logging
|
||||
import datetime
|
||||
import time
|
||||
from django.core.cache import cache
|
||||
from hashlib import md5
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DISCORD_URL = "https://discordapp.com/api"
|
||||
|
||||
AUTH_URL = "https://discordapp.com/api/oauth2/authorize"
|
||||
TOKEN_URL = "https://discordapp.com/api/oauth2/token"
|
||||
|
||||
"""
|
||||
Previously all we asked for was permission to kick members, manage roles, and manage nicknames.
|
||||
Users have reported weird unauthorized errors we don't understand. So now we ask for full server admin.
|
||||
It's almost fixed the problem.
|
||||
"""
|
||||
# kick members, manage roles, manage nicknames, create instant invite
|
||||
# BOT_PERMISSIONS = 0x00000002 + 0x10000000 + 0x08000000 + 0x00000001
|
||||
BOT_PERMISSIONS = 0x00000008
|
||||
|
||||
# get user ID, accept invite
|
||||
SCOPES = [
|
||||
'identify',
|
||||
'guilds.join',
|
||||
]
|
||||
|
||||
GROUP_CACHE_MAX_AGE = getattr(settings, 'DISCORD_GROUP_CACHE_MAX_AGE', 2 * 60 * 60) # 2 hours default
|
||||
|
||||
|
||||
class DiscordApiException(Exception):
|
||||
def __init__(self):
|
||||
super(Exception, self).__init__()
|
||||
|
||||
|
||||
class DiscordApiTooBusy(DiscordApiException):
|
||||
def __init__(self):
|
||||
super(DiscordApiException, self).__init__()
|
||||
self.message = "The Discord API is too busy to process this request now, please try again later."
|
||||
|
||||
|
||||
class DiscordApiBackoff(DiscordApiException):
|
||||
def __init__(self, retry_after, global_ratelimit):
|
||||
"""
|
||||
:param retry_after: int time to retry after in milliseconds
|
||||
:param global_ratelimit: bool Is the API under a global backoff
|
||||
"""
|
||||
super(DiscordApiException, self).__init__()
|
||||
self.retry_after = retry_after
|
||||
self.global_ratelimit = global_ratelimit
|
||||
|
||||
@property
|
||||
def retry_after_seconds(self):
|
||||
return math.ceil(self.retry_after / 1000)
|
||||
|
||||
|
||||
cache_time_format = '%Y-%m-%d %H:%M:%S.%f'
|
||||
|
||||
|
||||
def api_backoff(func):
|
||||
"""
|
||||
Decorator, Handles HTTP 429 "Too Many Requests" messages from the Discord API
|
||||
If blocking=True is specified, this function will block and retry
|
||||
the function up to max_retries=n times, or 3 if retries is not specified.
|
||||
If the API call still recieves a backoff timer this function will raise
|
||||
a <DiscordApiTooBusy> exception.
|
||||
If the caller chooses blocking=False, the decorator will raise a DiscordApiBackoff
|
||||
exception and the caller can choose to retry after the given timespan available in
|
||||
the retry_after property in seconds.
|
||||
"""
|
||||
|
||||
class PerformBackoff(Exception):
|
||||
def __init__(self, retry_after, retry_datetime, global_ratelimit):
|
||||
super(Exception, self).__init__()
|
||||
self.retry_after = int(retry_after)
|
||||
self.retry_datetime = retry_datetime
|
||||
self.global_ratelimit = global_ratelimit
|
||||
|
||||
@wraps(func)
|
||||
def decorated(*args, **kwargs):
|
||||
blocking = kwargs.get('blocking', False)
|
||||
retries = kwargs.get('max_retries', 3)
|
||||
|
||||
# Strip our parameters
|
||||
if 'max_retries' in kwargs:
|
||||
del kwargs['max_retries']
|
||||
if 'blocking' in kwargs:
|
||||
del kwargs['blocking']
|
||||
|
||||
cache_key = 'DISCORD_BACKOFF_' + func.__name__
|
||||
cache_global_key = 'DISCORD_BACKOFF_GLOBAL'
|
||||
|
||||
while retries > 0:
|
||||
try:
|
||||
try:
|
||||
# Check global backoff first, then route backoff
|
||||
existing_global_backoff = cache.get(cache_global_key)
|
||||
existing_backoff = existing_global_backoff or cache.get(cache_key)
|
||||
if existing_backoff:
|
||||
backoff_timer = datetime.datetime.strptime(existing_backoff, cache_time_format)
|
||||
if backoff_timer > datetime.datetime.utcnow():
|
||||
backoff_seconds = (backoff_timer - datetime.datetime.utcnow()).total_seconds()
|
||||
logger.debug("Still under backoff for %s seconds, backing off" % backoff_seconds)
|
||||
# Still under backoff
|
||||
raise PerformBackoff(
|
||||
retry_after=backoff_seconds,
|
||||
retry_datetime=backoff_timer,
|
||||
global_ratelimit=bool(existing_global_backoff)
|
||||
)
|
||||
logger.debug("Calling API calling function")
|
||||
return func(*args, **kwargs)
|
||||
except requests.HTTPError as e:
|
||||
if e.response.status_code == 429:
|
||||
try:
|
||||
retry_after = int(e.response.headers['Retry-After'])
|
||||
except (TypeError, KeyError):
|
||||
# Pick some random time
|
||||
retry_after = 5000
|
||||
|
||||
logger.info("Received backoff from API of %s seconds, handling" % retry_after)
|
||||
# Store value in redis
|
||||
backoff_until = (datetime.datetime.utcnow() +
|
||||
datetime.timedelta(milliseconds=retry_after))
|
||||
global_backoff = bool(e.response.headers.get('X-RateLimit-Global', False))
|
||||
if global_backoff:
|
||||
logger.info("Global backoff!!")
|
||||
cache.set(cache_global_key, backoff_until.strftime(cache_time_format), retry_after)
|
||||
else:
|
||||
cache.set(cache_key, backoff_until.strftime(cache_time_format), retry_after)
|
||||
raise PerformBackoff(retry_after=retry_after, retry_datetime=backoff_until,
|
||||
global_ratelimit=global_backoff)
|
||||
else:
|
||||
# Not 429, re-raise
|
||||
raise e
|
||||
except PerformBackoff as bo:
|
||||
# Sleep if we're blocking
|
||||
if blocking:
|
||||
logger.info("Blocking Back off from API calls for %s seconds" % bo.retry_after)
|
||||
time.sleep((10 if bo.retry_after > 10 else bo.retry_after) / 1000)
|
||||
else:
|
||||
# Otherwise raise exception and let caller handle the backoff
|
||||
raise DiscordApiBackoff(retry_after=bo.retry_after, global_ratelimit=bo.global_ratelimit)
|
||||
finally:
|
||||
retries -= 1
|
||||
if retries == 0:
|
||||
raise DiscordApiTooBusy()
|
||||
return decorated
|
||||
|
||||
|
||||
class DiscordOAuthManager:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_name(name):
|
||||
return name[:32]
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_group_name(name):
|
||||
return name[:100]
|
||||
|
||||
@staticmethod
|
||||
def generate_bot_add_url():
|
||||
return AUTH_URL + '?client_id=' + settings.DISCORD_APP_ID + '&scope=bot&permissions=' + str(BOT_PERMISSIONS)
|
||||
|
||||
@staticmethod
|
||||
def generate_oauth_redirect_url():
|
||||
oauth = OAuth2Session(settings.DISCORD_APP_ID, redirect_uri=settings.DISCORD_CALLBACK_URL, scope=SCOPES)
|
||||
url, state = oauth.authorization_url(AUTH_URL)
|
||||
return url
|
||||
|
||||
@staticmethod
|
||||
def _process_callback_code(code):
|
||||
oauth = OAuth2Session(settings.DISCORD_APP_ID, redirect_uri=settings.DISCORD_CALLBACK_URL)
|
||||
token = oauth.fetch_token(TOKEN_URL, client_secret=settings.DISCORD_APP_SECRET, code=code)
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def add_user(code, groups, nickname=None):
|
||||
try:
|
||||
token = DiscordOAuthManager._process_callback_code(code)['access_token']
|
||||
logger.debug("Received token from OAuth")
|
||||
|
||||
custom_headers = {'accept': 'application/json', 'authorization': 'Bearer ' + token}
|
||||
path = DISCORD_URL + "/users/@me"
|
||||
r = requests.get(path, headers=custom_headers)
|
||||
logger.debug("Got status code %s after retrieving Discord profile" % r.status_code)
|
||||
r.raise_for_status()
|
||||
|
||||
user_id = r.json()['id']
|
||||
|
||||
path = DISCORD_URL + "/guilds/" + str(settings.DISCORD_GUILD_ID) + "/members/" + str(user_id)
|
||||
group_ids = [DiscordOAuthManager._group_name_to_id(DiscordOAuthManager._sanitize_group_name(g)) for g in
|
||||
groups]
|
||||
data = {
|
||||
'roles': group_ids,
|
||||
'access_token': token,
|
||||
}
|
||||
if nickname:
|
||||
data['nick'] = DiscordOAuthManager._sanitize_name(nickname)
|
||||
custom_headers['authorization'] = 'Bot ' + settings.DISCORD_BOT_TOKEN
|
||||
r = requests.put(path, headers=custom_headers, json=data)
|
||||
logger.debug("Got status code %s after joining Discord server" % r.status_code)
|
||||
r.raise_for_status()
|
||||
|
||||
logger.info("Added Discord user ID %s to server." % user_id)
|
||||
return user_id
|
||||
except:
|
||||
logger.exception("Failed to add Discord user")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@api_backoff
|
||||
def update_nickname(user_id, nickname):
|
||||
nickname = DiscordOAuthManager._sanitize_name(nickname)
|
||||
custom_headers = {'content-type': 'application/json', 'authorization': 'Bot ' + settings.DISCORD_BOT_TOKEN}
|
||||
data = {'nick': nickname}
|
||||
path = DISCORD_URL + "/guilds/" + str(settings.DISCORD_GUILD_ID) + "/members/" + str(user_id)
|
||||
r = requests.patch(path, headers=custom_headers, json=data)
|
||||
logger.debug("Got status code %s after setting nickname for Discord user ID %s (%s)" % (
|
||||
r.status_code, user_id, nickname))
|
||||
if r.status_code == 404:
|
||||
logger.warn("Discord user ID %s could not be found in server." % user_id)
|
||||
return True
|
||||
r.raise_for_status()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def delete_user(user_id):
|
||||
try:
|
||||
custom_headers = {'accept': 'application/json', 'authorization': 'Bot ' + settings.DISCORD_BOT_TOKEN}
|
||||
path = DISCORD_URL + "/guilds/" + str(settings.DISCORD_GUILD_ID) + "/members/" + str(user_id)
|
||||
r = requests.delete(path, headers=custom_headers)
|
||||
logger.debug("Got status code %s after removing Discord user ID %s" % (r.status_code, user_id))
|
||||
if r.status_code == 404:
|
||||
logger.warn("Discord user ID %s already left the server." % user_id)
|
||||
return True
|
||||
r.raise_for_status()
|
||||
return True
|
||||
except:
|
||||
logger.exception("Failed to remove Discord user ID %s" % user_id)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _get_groups():
|
||||
custom_headers = {'accept': 'application/json', 'authorization': 'Bot ' + settings.DISCORD_BOT_TOKEN}
|
||||
path = DISCORD_URL + "/guilds/" + str(settings.DISCORD_GUILD_ID) + "/roles"
|
||||
r = requests.get(path, headers=custom_headers)
|
||||
logger.debug("Got status code %s after retrieving Discord roles" % r.status_code)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
@staticmethod
|
||||
def _generate_cache_role_key(name):
|
||||
return 'DISCORD_ROLE_NAME__%s' % md5(str(name).encode('utf-8')).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _group_name_to_id(name):
|
||||
name = DiscordOAuthManager._sanitize_group_name(name)
|
||||
|
||||
def get_or_make_role():
|
||||
groups = DiscordOAuthManager._get_groups()
|
||||
for g in groups:
|
||||
if g['name'] == name:
|
||||
return g['id']
|
||||
return DiscordOAuthManager._create_group(name)['id']
|
||||
return cache.get_or_set(DiscordOAuthManager._generate_cache_role_key(name), get_or_make_role, GROUP_CACHE_MAX_AGE)
|
||||
|
||||
@staticmethod
|
||||
def __generate_role(name, **kwargs):
|
||||
custom_headers = {'accept': 'application/json', 'authorization': 'Bot ' + settings.DISCORD_BOT_TOKEN}
|
||||
path = DISCORD_URL + "/guilds/" + str(settings.DISCORD_GUILD_ID) + "/roles"
|
||||
data = {'name': name}
|
||||
data.update(kwargs)
|
||||
r = requests.post(path, headers=custom_headers, json=data)
|
||||
logger.debug("Received status code %s after generating new role." % r.status_code)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
@staticmethod
|
||||
def __edit_role(role_id, **kwargs):
|
||||
custom_headers = {'content-type': 'application/json', 'authorization': 'Bot ' + settings.DISCORD_BOT_TOKEN}
|
||||
path = DISCORD_URL + "/guilds/" + str(settings.DISCORD_GUILD_ID) + "/roles/" + str(role_id)
|
||||
r = requests.patch(path, headers=custom_headers, json=kwargs)
|
||||
logger.debug("Received status code %s after editing role id %s" % (r.status_code, role_id))
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
@staticmethod
|
||||
def _create_group(name):
|
||||
return DiscordOAuthManager.__generate_role(name)
|
||||
|
||||
@staticmethod
|
||||
def _get_user(user_id):
|
||||
custom_headers = {'content-type': 'application/json', 'authorization': 'Bot ' + settings.DISCORD_BOT_TOKEN}
|
||||
path = DISCORD_URL + "/guilds/" + str(settings.DISCORD_GUILD_ID) + "/members/" + str(user_id)
|
||||
r = requests.get(path, headers=custom_headers)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
@staticmethod
|
||||
def _get_user_roles(user_id):
|
||||
user = DiscordOAuthManager._get_user(user_id)
|
||||
return user['roles']
|
||||
|
||||
@staticmethod
|
||||
def _modify_user_role(user_id, role_id, method):
|
||||
custom_headers = {'content-type': 'application/json', 'authorization': 'Bot ' + settings.DISCORD_BOT_TOKEN}
|
||||
path = DISCORD_URL + "/guilds/" + str(settings.DISCORD_GUILD_ID) + "/members/" + str(user_id) + "/roles/" + str(
|
||||
role_id)
|
||||
r = getattr(requests, method)(path, headers=custom_headers)
|
||||
r.raise_for_status()
|
||||
logger.debug("%s role %s for user %s" % (method, role_id, user_id))
|
||||
|
||||
@staticmethod
|
||||
@api_backoff
|
||||
def update_groups(user_id, groups):
|
||||
group_ids = [DiscordOAuthManager._group_name_to_id(DiscordOAuthManager._sanitize_group_name(g)) for g in groups]
|
||||
user_group_ids = DiscordOAuthManager._get_user_roles(user_id)
|
||||
for g in group_ids:
|
||||
if g not in user_group_ids:
|
||||
DiscordOAuthManager._modify_user_role(user_id, g, 'put')
|
||||
time.sleep(1) # we're gonna be hammering the API here
|
||||
for g in user_group_ids:
|
||||
if g not in group_ids:
|
||||
DiscordOAuthManager._modify_user_role(user_id, g, 'delete')
|
||||
time.sleep(1)
|
||||
175
allianceauth/services/modules/discord/managers.py
Normal file
175
allianceauth/services/modules/discord/managers.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import logging
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from requests_oauthlib import OAuth2Session
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from django.db import models
|
||||
from django.utils.timezone import now
|
||||
|
||||
from allianceauth.services.hooks import NameFormatter
|
||||
|
||||
from . import __title__
|
||||
from .app_settings import (
|
||||
DISCORD_APP_ID,
|
||||
DISCORD_APP_SECRET,
|
||||
DISCORD_BOT_TOKEN,
|
||||
DISCORD_CALLBACK_URL,
|
||||
DISCORD_GUILD_ID,
|
||||
DISCORD_SYNC_NAMES
|
||||
)
|
||||
from .discord_client import DiscordClient, DiscordApiBackoff
|
||||
from .utils import LoggerAddTag
|
||||
|
||||
|
||||
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
|
||||
|
||||
|
||||
class DiscordUserManager(models.Manager):
|
||||
"""Manager for DiscordUser"""
|
||||
|
||||
# full server admin
|
||||
BOT_PERMISSIONS = 0x00000008
|
||||
|
||||
# get user ID, accept invite
|
||||
SCOPES = [
|
||||
'identify',
|
||||
'guilds.join',
|
||||
]
|
||||
|
||||
def add_user(
|
||||
self,
|
||||
user: User,
|
||||
authorization_code: str,
|
||||
is_rate_limited: bool = True
|
||||
) -> bool:
|
||||
"""adds a new Discord user
|
||||
|
||||
Params:
|
||||
- user: Auth user to join
|
||||
- authorization_code: authorization code returns from oauth
|
||||
- is_rate_limited: When False will disable default rate limiting (use with care)
|
||||
|
||||
Returns: True on success, else False or raises exception
|
||||
"""
|
||||
try:
|
||||
nickname = self.user_formatted_nick(user) if DISCORD_SYNC_NAMES else None
|
||||
group_names = self.user_group_names(user)
|
||||
access_token = self._exchange_auth_code_for_token(authorization_code)
|
||||
user_client = DiscordClient(access_token, is_rate_limited=is_rate_limited)
|
||||
discord_user = user_client.current_user()
|
||||
user_id = discord_user['id']
|
||||
bot_client = self._bot_client(is_rate_limited=is_rate_limited)
|
||||
|
||||
if group_names:
|
||||
role_ids = self.model._guild_get_or_create_role_ids(
|
||||
bot_client, group_names
|
||||
)
|
||||
else:
|
||||
role_ids = None
|
||||
|
||||
created = bot_client.add_guild_member(
|
||||
guild_id=DISCORD_GUILD_ID,
|
||||
user_id=user_id,
|
||||
access_token=access_token,
|
||||
role_ids=role_ids,
|
||||
nick=nickname
|
||||
)
|
||||
if created is not False:
|
||||
if created is None:
|
||||
logger.debug(
|
||||
"User %s with Discord ID %s is already a member.",
|
||||
user,
|
||||
user_id,
|
||||
)
|
||||
self.update_or_create(
|
||||
user=user,
|
||||
defaults={
|
||||
'uid': user_id,
|
||||
'username': discord_user['username'][:32],
|
||||
'discriminator': discord_user['discriminator'][:4],
|
||||
'activated': now()
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
"Added user %s with Discord ID %s to Discord server", user, user_id
|
||||
)
|
||||
return True
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to add user %s with Discord ID %s to Discord server",
|
||||
user,
|
||||
user_id,
|
||||
)
|
||||
return False
|
||||
|
||||
except (HTTPError, ConnectionError, DiscordApiBackoff) as ex:
|
||||
logger.exception(
|
||||
'Failed to add user %s to Discord server: %s', user, ex
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def user_formatted_nick(user: User) -> str:
|
||||
"""returns the name of the given users main character with name formatting
|
||||
or None if user has no main
|
||||
"""
|
||||
from .auth_hooks import DiscordService
|
||||
|
||||
if user.profile.main_character:
|
||||
return NameFormatter(DiscordService(), user).format_name()
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def user_group_names(user: User) -> list:
|
||||
"""returns list of group names plus state the given user is a member of"""
|
||||
return [group.name for group in user.groups.all()] + [user.profile.state.name]
|
||||
|
||||
def user_has_account(self, user: User) -> bool:
|
||||
"""Returns True if the user has an Discord account, else False
|
||||
|
||||
only checks locally, does not hit the API
|
||||
"""
|
||||
return True if hasattr(user, self.model.USER_RELATED_NAME) else False
|
||||
|
||||
@classmethod
|
||||
def generate_bot_add_url(cls):
|
||||
params = urlencode({
|
||||
'client_id': DISCORD_APP_ID,
|
||||
'scope': 'bot',
|
||||
'permissions': str(cls.BOT_PERMISSIONS)
|
||||
|
||||
})
|
||||
return f'{DiscordClient.OAUTH_BASE_URL}?{params}'
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_redirect_url(cls):
|
||||
oauth = OAuth2Session(
|
||||
DISCORD_APP_ID, redirect_uri=DISCORD_CALLBACK_URL, scope=cls.SCOPES
|
||||
)
|
||||
url, state = oauth.authorization_url(DiscordClient.OAUTH_BASE_URL)
|
||||
return url
|
||||
|
||||
@staticmethod
|
||||
def _exchange_auth_code_for_token(authorization_code: str) -> str:
|
||||
oauth = OAuth2Session(DISCORD_APP_ID, redirect_uri=DISCORD_CALLBACK_URL)
|
||||
token = oauth.fetch_token(
|
||||
DiscordClient.OAUTH_TOKEN_URL,
|
||||
client_secret=DISCORD_APP_SECRET,
|
||||
code=authorization_code
|
||||
)
|
||||
logger.debug("Received token from OAuth")
|
||||
return token['access_token']
|
||||
|
||||
@classmethod
|
||||
def server_name(cls):
|
||||
"""returns the name of the Discord server"""
|
||||
return cls._bot_client().guild_name(DISCORD_GUILD_ID)
|
||||
|
||||
@staticmethod
|
||||
def _bot_client(is_rate_limited: bool = True):
|
||||
"""returns a bot client for access to the Discord API"""
|
||||
return DiscordClient(DISCORD_BOT_TOKEN, is_rate_limited=is_rate_limited)
|
||||
@@ -0,0 +1,40 @@
|
||||
# Generated by Django 2.2.12 on 2020-05-10 19:59
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('discord', '0002_service_permissions'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='discorduser',
|
||||
name='activated',
|
||||
field=models.DateTimeField(blank=True, default=None, help_text='Date & time this service account was activated', null=True),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='discorduser',
|
||||
name='discriminator',
|
||||
field=models.CharField(blank=True, default='', help_text="user's discriminator on Discord", max_length=4),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='discorduser',
|
||||
name='username',
|
||||
field=models.CharField(blank=True, db_index=True, default='', help_text="user's username on Discord", max_length=32),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='discorduser',
|
||||
name='uid',
|
||||
field=models.BigIntegerField(db_index=True, help_text="user's ID on Discord"),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='discorduser',
|
||||
name='user',
|
||||
field=models.OneToOneField(help_text='Auth user owning this Discord account', on_delete=django.db.models.deletion.CASCADE, primary_key=True, related_name='discord', serialize=False, to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
]
|
||||
@@ -1,18 +1,179 @@
|
||||
import logging
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy
|
||||
|
||||
from allianceauth.notifications import notify
|
||||
|
||||
from . import __title__
|
||||
from .app_settings import DISCORD_GUILD_ID
|
||||
from .discord_client import DiscordClient, DiscordApiBackoff
|
||||
from .managers import DiscordUserManager
|
||||
from .utils import LoggerAddTag
|
||||
|
||||
|
||||
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
|
||||
|
||||
|
||||
class DiscordUser(models.Model):
|
||||
user = models.OneToOneField(User,
|
||||
primary_key=True,
|
||||
on_delete=models.CASCADE,
|
||||
related_name='discord')
|
||||
uid = models.CharField(max_length=254)
|
||||
|
||||
def __str__(self):
|
||||
return "{} - {}".format(self.user.username, self.uid)
|
||||
USER_RELATED_NAME = 'discord'
|
||||
|
||||
user = models.OneToOneField(
|
||||
User,
|
||||
primary_key=True,
|
||||
on_delete=models.CASCADE,
|
||||
related_name=USER_RELATED_NAME,
|
||||
help_text='Auth user owning this Discord account'
|
||||
)
|
||||
uid = models.BigIntegerField(
|
||||
db_index=True,
|
||||
help_text='user\'s ID on Discord'
|
||||
)
|
||||
username = models.CharField(
|
||||
max_length=32,
|
||||
default='',
|
||||
blank=True,
|
||||
db_index=True,
|
||||
help_text='user\'s username on Discord'
|
||||
)
|
||||
discriminator = models.CharField(
|
||||
max_length=4,
|
||||
default='',
|
||||
blank=True,
|
||||
help_text='user\'s discriminator on Discord'
|
||||
)
|
||||
activated = models.DateTimeField(
|
||||
default=None,
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='Date & time this service account was activated'
|
||||
)
|
||||
|
||||
objects = DiscordUserManager()
|
||||
|
||||
class Meta:
|
||||
permissions = (
|
||||
("access_discord", u"Can access the Discord service"),
|
||||
("access_discord", "Can access the Discord service"),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return f'{self.user.username} - {self.uid}'
|
||||
|
||||
def __repr__(self):
|
||||
return f'{type(self).__name__}(user=\'{self.user}\', uid={self.uid})'
|
||||
|
||||
def update_nickname(self) -> bool:
|
||||
"""Update nickname with formatted name of main character
|
||||
|
||||
Returns:
|
||||
- True on success
|
||||
- None if user is no longer a member of the Discord server
|
||||
- False on error or raises exception
|
||||
"""
|
||||
requested_nick = DiscordUser.objects.user_formatted_nick(self.user)
|
||||
if requested_nick:
|
||||
client = DiscordUser.objects._bot_client()
|
||||
success = client.modify_guild_member(
|
||||
guild_id=DISCORD_GUILD_ID,
|
||||
user_id=self.uid,
|
||||
nick=requested_nick
|
||||
)
|
||||
if success:
|
||||
logger.info('Nickname for %s has been updated', self.user)
|
||||
else:
|
||||
logger.warning('Failed to update nickname for %s', self.user)
|
||||
return success
|
||||
|
||||
else:
|
||||
return False
|
||||
|
||||
def update_groups(self) -> bool:
|
||||
"""update groups for a user based on his current group memberships.
|
||||
Will add or remove roles of a user as needed.
|
||||
|
||||
Returns:
|
||||
- True on success
|
||||
- None if user is no longer a member of the Discord server
|
||||
- False on error or raises exception
|
||||
"""
|
||||
role_names = DiscordUser.objects.user_group_names(self.user)
|
||||
client = DiscordUser.objects._bot_client()
|
||||
requested_role_ids = self._guild_get_or_create_role_ids(client, role_names)
|
||||
logger.debug(
|
||||
'Requested to update groups for user %s: %s', self.user, requested_role_ids
|
||||
)
|
||||
success = client.modify_guild_member(
|
||||
guild_id=DISCORD_GUILD_ID,
|
||||
user_id=self.uid,
|
||||
role_ids=requested_role_ids
|
||||
)
|
||||
if success:
|
||||
logger.info('Groups for %s have been updated', self.user)
|
||||
else:
|
||||
logger.warning('Failed to update groups for %s', self.user)
|
||||
return success
|
||||
|
||||
def delete_user(
|
||||
self, notify_user: bool = False, is_rate_limited: bool = True
|
||||
) -> bool:
|
||||
"""Deletes the Discount user both on the server and locally
|
||||
|
||||
Params:
|
||||
- notify_user: When True will sent a notification to the user
|
||||
informing him about the deleting of his account
|
||||
- is_rate_limited: When False will disable default rate limiting (use with care)
|
||||
|
||||
Returns True when successful, otherwise False or raises exceptions
|
||||
Return None if user does no longer exist
|
||||
"""
|
||||
try:
|
||||
client = DiscordUser.objects._bot_client(is_rate_limited=is_rate_limited)
|
||||
success = client.remove_guild_member(
|
||||
guild_id=DISCORD_GUILD_ID, user_id=self.uid
|
||||
)
|
||||
if success is not False:
|
||||
deleted_count, _ = self.delete()
|
||||
if deleted_count > 0:
|
||||
if notify_user:
|
||||
notify(
|
||||
user=self.user,
|
||||
title=gettext_lazy('Discord Account Disabled'),
|
||||
message=gettext_lazy(
|
||||
'Your Discord account was disabeled automatically '
|
||||
'by Auth. If you think this was a mistake, '
|
||||
'please contact an admin.'
|
||||
),
|
||||
level='warning'
|
||||
)
|
||||
logger.info('Account for user %s was deleted.', self.user)
|
||||
return True
|
||||
else:
|
||||
logger.debug('Account for user %s was already deleted.', self.user)
|
||||
return None
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
'Failed to remove user %s from the Discord server', self.user
|
||||
)
|
||||
return False
|
||||
|
||||
except (HTTPError, ConnectionError, DiscordApiBackoff) as ex:
|
||||
logger.exception(
|
||||
'Failed to remove user %s from Discord server: %s', self.user, ex
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _guild_get_or_create_role_ids(client: DiscordClient, role_names: list) -> list:
|
||||
"""wrapper for DiscordClient.match_guild_roles_to_names()
|
||||
that only returns the list of IDs
|
||||
"""
|
||||
return [
|
||||
x[0]['id'] for x in client.match_guild_roles_to_names(
|
||||
guild_id=DISCORD_GUILD_ID, role_names=role_names
|
||||
)
|
||||
]
|
||||
|
||||
@@ -1,148 +1,187 @@
|
||||
import logging
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth.models import User
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from allianceauth.notifications import notify
|
||||
from celery import shared_task
|
||||
from celery import shared_task, chain
|
||||
from requests.exceptions import HTTPError
|
||||
from allianceauth.services.hooks import NameFormatter
|
||||
from .manager import DiscordOAuthManager, DiscordApiBackoff
|
||||
from .models import DiscordUser
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
from django.db.models.query import QuerySet
|
||||
|
||||
from allianceauth.services.tasks import QueueOnce
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from . import __title__
|
||||
from .app_settings import (
|
||||
DISCORD_TASKS_MAX_RETRIES, DISCORD_TASKS_RETRY_PAUSE, DISCORD_SYNC_NAMES
|
||||
)
|
||||
from .discord_client import DiscordApiBackoff
|
||||
from .models import DiscordUser
|
||||
from .utils import LoggerAddTag
|
||||
|
||||
|
||||
class DiscordTasks:
|
||||
def __init__(self):
|
||||
pass
|
||||
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
|
||||
|
||||
@classmethod
|
||||
def add_user(cls, user, code):
|
||||
groups = DiscordTasks.get_groups(user)
|
||||
nickname = None
|
||||
if settings.DISCORD_SYNC_NAMES:
|
||||
nickname = DiscordTasks.get_nickname(user)
|
||||
user_id = DiscordOAuthManager.add_user(code, groups, nickname=nickname)
|
||||
if user_id:
|
||||
discord_user = DiscordUser()
|
||||
discord_user.user = user
|
||||
discord_user.uid = user_id
|
||||
discord_user.save()
|
||||
return True
|
||||
return False
|
||||
# task priority of bulk tasks
|
||||
BULK_TASK_PRIORITY = 6
|
||||
|
||||
@classmethod
|
||||
def delete_user(cls, user, notify_user=False):
|
||||
if cls.has_account(user):
|
||||
logger.debug("User %s has discord account %s. Deleting." % (user, user.discord.uid))
|
||||
if DiscordOAuthManager.delete_user(user.discord.uid):
|
||||
user.discord.delete()
|
||||
if notify_user:
|
||||
notify(user, 'Discord Account Disabled', level='danger')
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def has_account(cls, user):
|
||||
"""
|
||||
Check if the user has an account (has a DiscordUser record)
|
||||
:param user: django.contrib.auth.models.User
|
||||
:return: bool
|
||||
"""
|
||||
@shared_task(
|
||||
bind=True, name='discord.update_groups', base=QueueOnce, max_retries=None
|
||||
)
|
||||
def update_groups(self, user_pk: int) -> None:
|
||||
"""Update roles on Discord for given user according to his current groups
|
||||
|
||||
Params:
|
||||
- user_pk: PK of given user
|
||||
"""
|
||||
_task_perform_user_action(self, user_pk, 'update_groups')
|
||||
|
||||
|
||||
@shared_task(
|
||||
bind=True, name='discord.update_nickname', base=QueueOnce, max_retries=None
|
||||
)
|
||||
def update_nickname(self, user_pk: int) -> None:
|
||||
"""Set nickname on Discord for given user to his main character name
|
||||
|
||||
Params:
|
||||
- user_pk: PK of given user
|
||||
"""
|
||||
_task_perform_user_action(self, user_pk, 'update_nickname')
|
||||
|
||||
|
||||
@shared_task(
|
||||
bind=True, name='discord.delete_user', base=QueueOnce, max_retries=None
|
||||
)
|
||||
def delete_user(self, user_pk: int, notify_user: bool = False) -> None:
|
||||
"""Delete Discord user
|
||||
|
||||
Params:
|
||||
- user_pk: PK of given user
|
||||
"""
|
||||
_task_perform_user_action(self, user_pk, 'delete_user', notify_user=notify_user)
|
||||
|
||||
|
||||
def _task_perform_user_action(self, user_pk: int, method: str, **kwargs) -> None:
|
||||
"""perform a user related action incl. managing all exceptions"""
|
||||
logger.debug("Starting %s for user with pk %s", method, user_pk)
|
||||
user = User.objects.get(pk=user_pk)
|
||||
if DiscordUser.objects.user_has_account(user):
|
||||
logger.info("Running %s for user %s", method, user)
|
||||
try:
|
||||
user.discord
|
||||
except ObjectDoesNotExist:
|
||||
return False
|
||||
success = getattr(user.discord, method)(**kwargs)
|
||||
|
||||
except DiscordApiBackoff as bo:
|
||||
logger.info(
|
||||
"API back off for %s wth user %s due to %r, retrying in %s seconds",
|
||||
method,
|
||||
user,
|
||||
bo,
|
||||
bo.retry_after_seconds
|
||||
)
|
||||
raise self.retry(countdown=bo.retry_after_seconds)
|
||||
|
||||
except AttributeError:
|
||||
raise ValueError(f'{method} not a valid method for DiscordUser: %r')
|
||||
|
||||
except (HTTPError, ConnectionError):
|
||||
logger.warning(
|
||||
'%s failed for user %s, retrying in %d secs',
|
||||
method,
|
||||
user,
|
||||
DISCORD_TASKS_RETRY_PAUSE,
|
||||
exc_info=True
|
||||
)
|
||||
if self.request.retries < DISCORD_TASKS_MAX_RETRIES:
|
||||
raise self.retry(countdown=DISCORD_TASKS_RETRY_PAUSE)
|
||||
else:
|
||||
logger.error(
|
||||
'%s failed for user %s after max retries',
|
||||
method,
|
||||
user,
|
||||
exc_info=True
|
||||
)
|
||||
except Exception:
|
||||
logger.error(
|
||||
'%s for %s failed due to unexpected exception',
|
||||
method,
|
||||
user,
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
else:
|
||||
return True
|
||||
if success is None and method != 'delete_user':
|
||||
delete_user.delay(user.pk, notify_user=True)
|
||||
|
||||
@staticmethod
|
||||
@shared_task(bind=True, name='discord.update_groups', base=QueueOnce)
|
||||
def update_groups(self, pk):
|
||||
user = User.objects.get(pk=pk)
|
||||
logger.debug("Updating discord groups for user %s" % user)
|
||||
if DiscordTasks.has_account(user):
|
||||
groups = DiscordTasks.get_groups(user)
|
||||
logger.debug("Updating user %s discord groups to %s" % (user, groups))
|
||||
try:
|
||||
DiscordOAuthManager.update_groups(user.discord.uid, groups)
|
||||
except DiscordApiBackoff as bo:
|
||||
logger.info("Discord group sync API back off for %s, "
|
||||
"retrying in %s seconds" % (user, bo.retry_after_seconds))
|
||||
raise self.retry(countdown=bo.retry_after_seconds)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 404:
|
||||
try:
|
||||
if e.response.json()['code'] == 10007:
|
||||
# user has left the server
|
||||
DiscordTasks.delete_user(user)
|
||||
return
|
||||
finally:
|
||||
raise e
|
||||
except Exception as e:
|
||||
if self:
|
||||
logger.exception("Discord group sync failed for %s, retrying in 10 mins" % user)
|
||||
raise self.retry(countdown=60 * 10)
|
||||
else:
|
||||
# Rethrow
|
||||
raise e
|
||||
logger.debug("Updated user %s discord groups." % user)
|
||||
else:
|
||||
logger.debug("User does not have a discord account, skipping")
|
||||
else:
|
||||
logger.debug(
|
||||
'User %s does not have a discord account, skipping %s', user, method
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@shared_task(name='discord.update_all_groups')
|
||||
def update_all_groups():
|
||||
logger.debug("Updating ALL discord groups")
|
||||
for discord_user in DiscordUser.objects.exclude(uid__exact=''):
|
||||
DiscordTasks.update_groups.delay(discord_user.user.pk)
|
||||
|
||||
@staticmethod
|
||||
@shared_task(bind=True, name='discord.update_nickname', base=QueueOnce)
|
||||
def update_nickname(self, pk):
|
||||
user = User.objects.get(pk=pk)
|
||||
logger.debug("Updating discord nickname for user %s" % user)
|
||||
if DiscordTasks.has_account(user):
|
||||
if user.profile.main_character:
|
||||
character = user.profile.main_character
|
||||
logger.debug("Updating user %s discord nickname to %s" % (user, character.character_name))
|
||||
try:
|
||||
DiscordOAuthManager.update_nickname(user.discord.uid, DiscordTasks.get_nickname(user))
|
||||
except DiscordApiBackoff as bo:
|
||||
logger.info("Discord nickname update API back off for %s, "
|
||||
"retrying in %s seconds" % (user, bo.retry_after_seconds))
|
||||
raise self.retry(countdown=bo.retry_after_seconds)
|
||||
except Exception as e:
|
||||
if self:
|
||||
logger.exception("Discord nickname sync failed for %s, retrying in 10 mins" % user)
|
||||
raise self.retry(countdown=60 * 10)
|
||||
else:
|
||||
# Rethrow
|
||||
raise e
|
||||
logger.debug("Updated user %s discord nickname." % user)
|
||||
else:
|
||||
logger.debug("User %s does not have a main character" % user)
|
||||
else:
|
||||
logger.debug("User %s does not have a discord account" % user)
|
||||
@shared_task(name='discord.update_all_groups')
|
||||
def update_all_groups() -> None:
|
||||
"""Update roles for all known users with a Discord account."""
|
||||
discord_users_qs = DiscordUser.objects.all()
|
||||
_bulk_update_groups_for_users(discord_users_qs)
|
||||
|
||||
@staticmethod
|
||||
@shared_task(name='discord.update_all_nicknames')
|
||||
def update_all_nicknames():
|
||||
logger.debug("Updating ALL discord nicknames")
|
||||
for discord_user in DiscordUser.objects.exclude(uid__exact=''):
|
||||
DiscordTasks.update_nickname.delay(discord_user.user.pk)
|
||||
|
||||
@classmethod
|
||||
def disable(cls):
|
||||
DiscordUser.objects.all().delete()
|
||||
@shared_task(name='discord.update_groups_bulk')
|
||||
def update_groups_bulk(user_pks: list) -> None:
|
||||
"""Update roles for list of users with a Discord account in bulk."""
|
||||
discord_users_qs = DiscordUser.objects\
|
||||
.filter(user__pk__in=user_pks)\
|
||||
.select_related()
|
||||
_bulk_update_groups_for_users(discord_users_qs)
|
||||
|
||||
@staticmethod
|
||||
def get_nickname(user):
|
||||
from .auth_hooks import DiscordService
|
||||
return NameFormatter(DiscordService(), user).format_name()
|
||||
|
||||
@staticmethod
|
||||
def get_groups(user):
|
||||
return [g.name for g in user.groups.all()] + [user.profile.state.name]
|
||||
def _bulk_update_groups_for_users(discord_users_qs: QuerySet) -> None:
|
||||
logger.info(
|
||||
"Starting to bulk update discord roles for %d users", discord_users_qs.count()
|
||||
)
|
||||
update_groups_chain = list()
|
||||
for discord_user in discord_users_qs:
|
||||
update_groups_chain.append(update_groups.si(discord_user.user.pk))
|
||||
|
||||
chain(update_groups_chain).apply_async(priority=BULK_TASK_PRIORITY)
|
||||
|
||||
|
||||
@shared_task(name='discord.update_all_nicknames')
|
||||
def update_all_nicknames() -> None:
|
||||
"""Update nicknames for all known users with a Discord account."""
|
||||
discord_users_qs = DiscordUser.objects.all()
|
||||
_bulk_update_nicknames_for_users(discord_users_qs)
|
||||
|
||||
|
||||
@shared_task(name='discord.update_nicknames_bulk')
|
||||
def update_nicknames_bulk(user_pks: list) -> None:
|
||||
"""Update nicknames for list of users with a Discord account in bulk."""
|
||||
discord_users_qs = DiscordUser.objects\
|
||||
.filter(user__pk__in=user_pks)\
|
||||
.select_related()
|
||||
_bulk_update_nicknames_for_users(discord_users_qs)
|
||||
|
||||
|
||||
def _bulk_update_nicknames_for_users(discord_users_qs: QuerySet) -> None:
|
||||
logger.info(
|
||||
"Starting to bulk update discord nicknames for %d users",
|
||||
discord_users_qs.count()
|
||||
)
|
||||
update_nicknames_chain = list()
|
||||
for discord_user in discord_users_qs:
|
||||
update_nicknames_chain.append(update_nickname.si(discord_user.user.pk))
|
||||
|
||||
chain(update_nicknames_chain).apply_async(priority=BULK_TASK_PRIORITY)
|
||||
|
||||
|
||||
@shared_task(name='discord.update_all')
|
||||
def update_all() -> None:
|
||||
"""Updates groups and nicknames (when activated) for all users."""
|
||||
discord_users_qs = DiscordUser.objects.all()
|
||||
logger.info(
|
||||
'Starting to bulk update all %s Discord users', discord_users_qs.count()
|
||||
)
|
||||
update_all_chain = list()
|
||||
for discord_user in discord_users_qs:
|
||||
update_all_chain.append(update_groups.si(discord_user.user.pk))
|
||||
if DISCORD_SYNC_NAMES:
|
||||
update_all_chain.append(update_nickname.si(discord_user.user.pk))
|
||||
|
||||
chain(update_all_chain).apply_async(priority=BULK_TASK_PRIORITY)
|
||||
|
||||
@@ -3,10 +3,18 @@
|
||||
|
||||
<tr>
|
||||
<td class="text-center">Discord</td>
|
||||
<td class="text-center"></td>
|
||||
<td class="text-center"><a href="https://discordapp.com/channels/{{ DISCORD_SERVER_ID }}/{{ DISCORD_SERVER_ID}}">https://discordapp.com</a></td>
|
||||
<td class="text-center">
|
||||
{% if not discord_uid %}
|
||||
{% if not user_has_account %}
|
||||
(not activated)
|
||||
{% else %}
|
||||
{{discord_username}}
|
||||
{% endif %}
|
||||
</td>
|
||||
<td class="text-center">
|
||||
{{server_name}}
|
||||
</td>
|
||||
<td class="text-center">
|
||||
{% if not user_has_account %}
|
||||
<a href="{% url 'discord:activate' %}" title="Activate" class="btn btn-warning">
|
||||
<span class="glyphicon glyphicon-ok"></span>
|
||||
</a>
|
||||
@@ -20,7 +28,9 @@
|
||||
{% endif %}
|
||||
{% if request.user.is_superuser %}
|
||||
<div class="text-center" style="padding-top:5px;">
|
||||
<a type="button" class="btn btn-success" href="{% url 'discord:add_bot' %}">{% trans "Link Discord Server" %}</a>
|
||||
<a type="button" class="btn btn-success" href="{% url 'discord:add_bot' %}">
|
||||
{% trans "Link Discord Server" %}
|
||||
</a>
|
||||
</div>
|
||||
{% endif %}
|
||||
</td>
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
from django.contrib.auth.models import User, Group, Permission
|
||||
from django.contrib.auth.models import Group, Permission
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
|
||||
DEFAULT_AUTH_GROUP = 'Member'
|
||||
MODULE_PATH = 'allianceauth.services.modules.discord'
|
||||
|
||||
def add_permissions():
|
||||
TEST_GUILD_ID = 123456789012345678
|
||||
TEST_USER_ID = 198765432012345678
|
||||
TEST_USER_NAME = 'Peter Parker'
|
||||
TEST_MAIN_NAME = 'Spiderman'
|
||||
TEST_MAIN_ID = 1005
|
||||
|
||||
|
||||
def add_permissions_to_members():
|
||||
permission = Permission.objects.get(codename='access_discord')
|
||||
members = Group.objects.get_or_create(name=DEFAULT_AUTH_GROUP)[0]
|
||||
AuthUtils.add_permissions_to_groups([permission], [members])
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import TestCase, RequestFactory
|
||||
from django.contrib import admin
|
||||
from django.contrib.admin.sites import AdminSite
|
||||
from django.contrib.auth.models import User
|
||||
from django.utils.timezone import now
|
||||
|
||||
from allianceauth.authentication.models import CharacterOwnership
|
||||
from allianceauth.eveonline.models import (
|
||||
@@ -18,17 +16,21 @@ from ....admin import (
|
||||
MainCorporationsFilter,
|
||||
MainAllianceFilter
|
||||
)
|
||||
from ..admin import (
|
||||
DiscordUser,
|
||||
DiscordUserAdmin
|
||||
)
|
||||
from ..admin import DiscordUserAdmin
|
||||
from ..models import DiscordUser
|
||||
|
||||
|
||||
class TestDiscordUserAdmin(TestCase):
|
||||
class TestDataMixin(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
EveCharacter.objects.all().delete()
|
||||
EveCorporationInfo.objects.all().delete()
|
||||
EveAllianceInfo.objects.all().delete()
|
||||
User.objects.all().delete()
|
||||
DiscordUser.objects.all().delete()
|
||||
|
||||
# user 1 - corp and alliance, normal user
|
||||
cls.character_1 = EveCharacter.objects.create(
|
||||
@@ -83,7 +85,10 @@ class TestDiscordUserAdmin(TestCase):
|
||||
cls.user_1.profile.save()
|
||||
DiscordUser.objects.create(
|
||||
user=cls.user_1,
|
||||
uid=1001
|
||||
uid=1001,
|
||||
username='Bruce Wayne',
|
||||
discriminator='1234',
|
||||
activated=now()
|
||||
)
|
||||
|
||||
# user 2 - corp only, staff
|
||||
@@ -156,18 +161,20 @@ class TestDiscordUserAdmin(TestCase):
|
||||
uid=1003
|
||||
)
|
||||
|
||||
|
||||
def setUp(self):
|
||||
self.factory = RequestFactory()
|
||||
self.modeladmin = DiscordUserAdmin(
|
||||
model=DiscordUser, admin_site=AdminSite()
|
||||
)
|
||||
|
||||
# column rendering
|
||||
|
||||
class TestColumnRendering(TestDataMixin, TestCase):
|
||||
|
||||
def test_user_profile_pic_u1(self):
|
||||
expected = ('<img src="https://images.evetech.net/characters/1001/'
|
||||
'portrait?size=32" class="img-circle">')
|
||||
expected = (
|
||||
'<img src="https://images.evetech.net/characters/1001/'
|
||||
'portrait?size=32" class="img-circle">'
|
||||
)
|
||||
self.assertEqual(user_profile_pic(self.user_1.discord), expected)
|
||||
|
||||
def test_user_profile_pic_u3(self):
|
||||
@@ -204,9 +211,26 @@ class TestDiscordUserAdmin(TestCase):
|
||||
result = user_main_organization(self.user_3.discord)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_uid(self):
|
||||
expected = 1001
|
||||
result = self.modeladmin._uid(self.user_1.discord)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_username_when_defined(self):
|
||||
expected = 'Bruce Wayne#1234'
|
||||
result = self.modeladmin._username(self.user_1.discord)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_username_when_not_defined(self):
|
||||
expected = ''
|
||||
result = self.modeladmin._username(self.user_2.discord)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
# actions
|
||||
|
||||
# filters
|
||||
|
||||
class TestFilters(TestDataMixin, TestCase):
|
||||
|
||||
def test_filter_main_corporations(self):
|
||||
|
||||
class DiscordUserAdminTest(ServicesUserAdmin):
|
||||
@@ -228,8 +252,7 @@ class TestDiscordUserAdmin(TestCase):
|
||||
|
||||
# Make sure the correct queryset is returned
|
||||
request = self.factory.get(
|
||||
'/',
|
||||
{'main_corporation_id__exact': self.character_1.corporation_id}
|
||||
'/', {'main_corporation_id__exact': self.character_1.corporation_id}
|
||||
)
|
||||
request.user = self.user_1
|
||||
changelist = my_modeladmin.get_changelist_instance(request)
|
||||
@@ -250,19 +273,17 @@ class TestDiscordUserAdmin(TestCase):
|
||||
changelist = my_modeladmin.get_changelist_instance(request)
|
||||
filters = changelist.get_filters(request)
|
||||
filterspec = filters[0][0]
|
||||
expected = [
|
||||
('3001', 'Wayne Enterprises'),
|
||||
expected = [
|
||||
('3001', 'Wayne Enterprises'),
|
||||
]
|
||||
self.assertEqual(filterspec.lookup_choices, expected)
|
||||
|
||||
# Make sure the correct queryset is returned
|
||||
request = self.factory.get(
|
||||
'/',
|
||||
{'main_alliance_id__exact': self.character_1.alliance_id}
|
||||
'/', {'main_alliance_id__exact': self.character_1.alliance_id}
|
||||
)
|
||||
request.user = self.user_1
|
||||
changelist = my_modeladmin.get_changelist_instance(request)
|
||||
queryset = changelist.get_queryset(request)
|
||||
expected = [self.user_1.discord]
|
||||
self.assertSetEqual(set(queryset), set(expected))
|
||||
|
||||
140
allianceauth/services/modules/discord/tests/test_auth_hooks.py
Normal file
140
allianceauth/services/modules/discord/tests/test_auth_hooks.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import TestCase, RequestFactory
|
||||
from django.test.utils import override_settings
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
|
||||
from . import TEST_USER_NAME, TEST_USER_ID, add_permissions_to_members, MODULE_PATH
|
||||
from ..auth_hooks import DiscordService
|
||||
from ..models import DiscordUser, DiscordClient
|
||||
from ..utils import set_logger_to_file
|
||||
|
||||
|
||||
logger = set_logger_to_file(MODULE_PATH + '.auth_hooks', __file__)
|
||||
|
||||
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True)
|
||||
class TestDiscordService(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.member = AuthUtils.create_member(TEST_USER_NAME)
|
||||
DiscordUser.objects.create(
|
||||
user=self.member,
|
||||
uid=TEST_USER_ID,
|
||||
username=TEST_USER_NAME,
|
||||
discriminator='1234'
|
||||
)
|
||||
self.none_member = AuthUtils.create_user('Lex Luther')
|
||||
self.service = DiscordService
|
||||
add_permissions_to_members()
|
||||
self.factory = RequestFactory()
|
||||
|
||||
def test_service_enabled(self):
|
||||
service = self.service()
|
||||
self.assertTrue(service.service_active_for_user(self.member))
|
||||
self.assertFalse(service.service_active_for_user(self.none_member))
|
||||
|
||||
@patch(MODULE_PATH + '.tasks.update_all_groups')
|
||||
def test_update_all_groups(self, mock_update_all_groups):
|
||||
service = self.service()
|
||||
service.update_all_groups()
|
||||
self.assertTrue(mock_update_all_groups.delay.called)
|
||||
|
||||
@patch(MODULE_PATH + '.tasks.update_groups_bulk')
|
||||
def test_update_groups_bulk(self, mock_update_groups_bulk):
|
||||
service = self.service()
|
||||
service.update_groups_bulk([self.member])
|
||||
self.assertTrue(mock_update_groups_bulk.delay.called)
|
||||
|
||||
@patch(MODULE_PATH + '.tasks.update_groups')
|
||||
def test_update_groups_for_member(self, mock_update_groups):
|
||||
service = self.service()
|
||||
service.update_groups(self.member)
|
||||
self.assertTrue(mock_update_groups.apply_async.called)
|
||||
|
||||
@patch(MODULE_PATH + '.tasks.update_groups')
|
||||
def test_update_groups_for_none_member(self, mock_update_groups):
|
||||
service = self.service()
|
||||
service.update_groups(self.none_member)
|
||||
self.assertFalse(mock_update_groups.apply_async.called)
|
||||
|
||||
@patch(MODULE_PATH + '.models.notify')
|
||||
@patch(MODULE_PATH + '.tasks.DiscordUser')
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
def test_validate_user(
|
||||
self, mock_DiscordClient, mock_DiscordUser, mock_notify
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
|
||||
# Test member is not deleted
|
||||
service = self.service()
|
||||
service.validate_user(self.member)
|
||||
self.assertTrue(DiscordUser.objects.filter(user=self.member).exists())
|
||||
|
||||
# Test none member is deleted
|
||||
DiscordUser.objects.create(user=self.none_member, uid=TEST_USER_ID)
|
||||
service.validate_user(self.none_member)
|
||||
self.assertFalse(DiscordUser.objects.filter(user=self.none_member).exists())
|
||||
|
||||
@patch(MODULE_PATH + '.tasks.update_nickname')
|
||||
def test_sync_nickname(self, mock_update_nickname):
|
||||
service = self.service()
|
||||
service.sync_nickname(self.member)
|
||||
self.assertTrue(mock_update_nickname.apply_async.called)
|
||||
|
||||
@patch(MODULE_PATH + '.tasks.update_nicknames_bulk')
|
||||
def test_sync_nicknames_bulk(self, mock_update_nicknames_bulk):
|
||||
service = self.service()
|
||||
service.sync_nicknames_bulk([self.member])
|
||||
self.assertTrue(mock_update_nicknames_bulk.delay.called)
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
def test_delete_user_is_member(self, mock_DiscordClient):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
|
||||
service = self.service()
|
||||
service.delete_user(self.member)
|
||||
|
||||
self.assertTrue(mock_DiscordClient.return_value.remove_guild_member.called)
|
||||
self.assertFalse(DiscordUser.objects.filter(user=self.member).exists())
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
def test_delete_user_is_not_member(self, mock_DiscordClient):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
|
||||
service = self.service()
|
||||
service.delete_user(self.none_member)
|
||||
|
||||
self.assertFalse(mock_DiscordClient.return_value.remove_guild_member.called)
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
def test_render_services_ctrl_with_username(self, mock_DiscordClient):
|
||||
service = self.service()
|
||||
request = self.factory.get('/services/')
|
||||
request.user = self.member
|
||||
|
||||
response = service.render_services_ctrl(request)
|
||||
self.assertTemplateUsed(service.service_ctrl_template)
|
||||
self.assertIn('/discord/reset/', response)
|
||||
self.assertIn('/discord/deactivate/', response)
|
||||
|
||||
# Test register becomes available
|
||||
self.member.discord.delete()
|
||||
self.member.refresh_from_db()
|
||||
request.user = self.member
|
||||
response = service.render_services_ctrl(request)
|
||||
self.assertIn('/discord/activate/', response)
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
def test_render_services_ctrl_wo_username(self, mock_DiscordClient):
|
||||
my_member = AuthUtils.create_member('John Doe')
|
||||
DiscordUser.objects.create(user=my_member, uid=111222333)
|
||||
service = self.service()
|
||||
request = self.factory.get('/services/')
|
||||
request.user = my_member
|
||||
|
||||
response = service.render_services_ctrl(request)
|
||||
self.assertTemplateUsed(service.service_ctrl_template)
|
||||
self.assertIn('/discord/reset/', response)
|
||||
self.assertIn('/discord/deactivate/', response)
|
||||
@@ -1,127 +0,0 @@
|
||||
from unittest import mock
|
||||
|
||||
from django.test import TestCase, RequestFactory
|
||||
from django.contrib.auth.models import User
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
|
||||
from ..auth_hooks import DiscordService
|
||||
from ..models import DiscordUser
|
||||
from ..tasks import DiscordTasks
|
||||
from ..manager import DiscordOAuthManager
|
||||
|
||||
from . import DEFAULT_AUTH_GROUP, add_permissions, MODULE_PATH
|
||||
|
||||
|
||||
class DiscordHooksTestCase(TestCase):
|
||||
def setUp(self):
|
||||
self.member = 'member_user'
|
||||
member = AuthUtils.create_member(self.member)
|
||||
DiscordUser.objects.create(user=member, uid='12345')
|
||||
self.none_user = 'none_user'
|
||||
none_user = AuthUtils.create_user(self.none_user)
|
||||
self.service = DiscordService
|
||||
add_permissions()
|
||||
|
||||
def test_has_account(self):
|
||||
member = User.objects.get(username=self.member)
|
||||
none_user = User.objects.get(username=self.none_user)
|
||||
self.assertTrue(DiscordTasks.has_account(member))
|
||||
self.assertFalse(DiscordTasks.has_account(none_user))
|
||||
|
||||
def test_service_enabled(self):
|
||||
service = self.service()
|
||||
member = User.objects.get(username=self.member)
|
||||
none_user = User.objects.get(username=self.none_user)
|
||||
|
||||
self.assertTrue(service.service_active_for_user(member))
|
||||
self.assertFalse(service.service_active_for_user(none_user))
|
||||
|
||||
@mock.patch(MODULE_PATH + '.tasks.DiscordOAuthManager')
|
||||
def test_update_all_groups(self, manager):
|
||||
service = self.service()
|
||||
service.update_all_groups()
|
||||
# Check member and blue user have groups updated
|
||||
self.assertTrue(manager.update_groups.called)
|
||||
self.assertEqual(manager.update_groups.call_count, 1)
|
||||
|
||||
def test_update_groups(self):
|
||||
# Check member has Member group updated
|
||||
with mock.patch(MODULE_PATH + '.tasks.DiscordOAuthManager') as manager:
|
||||
service = self.service()
|
||||
member = User.objects.get(username=self.member)
|
||||
AuthUtils.disconnect_signals()
|
||||
service.update_groups(member)
|
||||
self.assertTrue(manager.update_groups.called)
|
||||
args, kwargs = manager.update_groups.call_args
|
||||
user_id, groups = args
|
||||
self.assertIn(DEFAULT_AUTH_GROUP, groups)
|
||||
self.assertEqual(user_id, member.discord.uid)
|
||||
|
||||
# Check none user does not have groups updated
|
||||
with mock.patch(MODULE_PATH + '.tasks.DiscordOAuthManager') as manager:
|
||||
service = self.service()
|
||||
none_user = User.objects.get(username=self.none_user)
|
||||
service.update_groups(none_user)
|
||||
self.assertFalse(manager.update_groups.called)
|
||||
|
||||
@mock.patch(MODULE_PATH + '.tasks.DiscordOAuthManager')
|
||||
def test_validate_user(self, manager):
|
||||
service = self.service()
|
||||
# Test member is not deleted
|
||||
member = User.objects.get(username=self.member)
|
||||
service.validate_user(member)
|
||||
self.assertTrue(member.discord)
|
||||
|
||||
# Test none user is deleted
|
||||
none_user = User.objects.get(username=self.none_user)
|
||||
DiscordUser.objects.create(user=none_user, uid='abc123')
|
||||
service.validate_user(none_user)
|
||||
self.assertTrue(manager.delete_user.called)
|
||||
with self.assertRaises(ObjectDoesNotExist):
|
||||
none_discord = User.objects.get(username=self.none_user).discord
|
||||
|
||||
@mock.patch(MODULE_PATH + '.tasks.DiscordOAuthManager')
|
||||
def test_sync_nickname(self, manager):
|
||||
service = self.service()
|
||||
member = User.objects.get(username=self.member)
|
||||
AuthUtils.add_main_character(member, 'test user', '12345', corp_ticker='AAUTH')
|
||||
|
||||
service.sync_nickname(member)
|
||||
|
||||
self.assertTrue(manager.update_nickname.called)
|
||||
args, kwargs = manager.update_nickname.call_args
|
||||
self.assertEqual(args[0], member.discord.uid)
|
||||
self.assertEqual(args[1], 'test user')
|
||||
|
||||
@mock.patch(MODULE_PATH + '.tasks.DiscordOAuthManager')
|
||||
def test_delete_user(self, manager):
|
||||
member = User.objects.get(username=self.member)
|
||||
|
||||
service = self.service()
|
||||
result = service.delete_user(member)
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(manager.delete_user.called)
|
||||
with self.assertRaises(ObjectDoesNotExist):
|
||||
discord_user = User.objects.get(username=self.member).discord
|
||||
|
||||
def test_render_services_ctrl(self):
|
||||
service = self.service()
|
||||
member = User.objects.get(username=self.member)
|
||||
request = RequestFactory().get('/services/')
|
||||
request.user = member
|
||||
|
||||
response = service.render_services_ctrl(request)
|
||||
self.assertTemplateUsed(service.service_ctrl_template)
|
||||
self.assertIn('/discord/reset/', response)
|
||||
self.assertIn('/discord/deactivate/', response)
|
||||
|
||||
# Test register becomes available
|
||||
member.discord.delete()
|
||||
member = User.objects.get(username=self.member)
|
||||
request.user = member
|
||||
response = service.render_services_ctrl(request)
|
||||
self.assertIn('/discord/activate/', response)
|
||||
|
||||
# TODO: Test update nicknames
|
||||
@@ -0,0 +1,62 @@
|
||||
from django_webtest import WebTest
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.shortcuts import reverse
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
|
||||
from . import (
|
||||
add_permissions_to_members,
|
||||
MODULE_PATH,
|
||||
TEST_USER_NAME,
|
||||
TEST_MAIN_NAME,
|
||||
TEST_MAIN_ID
|
||||
)
|
||||
|
||||
|
||||
class TestServiceUserActivation(WebTest):
|
||||
|
||||
def setUp(self):
|
||||
self.member = AuthUtils.create_member(TEST_USER_NAME)
|
||||
AuthUtils.add_main_character_2(
|
||||
self.member,
|
||||
TEST_MAIN_NAME,
|
||||
TEST_MAIN_ID,
|
||||
disconnect_signals=True
|
||||
)
|
||||
add_permissions_to_members()
|
||||
|
||||
@patch(MODULE_PATH + '.views.messages')
|
||||
@patch(MODULE_PATH + '.models.DiscordUser.objects.add_user')
|
||||
@patch(MODULE_PATH + '.managers.OAuth2Session')
|
||||
def test_user_activation(
|
||||
self, mock_OAuth2Session, mock_add_user, mock_messages
|
||||
):
|
||||
authentication_code = 'auth_code'
|
||||
mock_add_user.return_value = True
|
||||
oauth_url = 'https://www.example.com/oauth'
|
||||
state = ''
|
||||
mock_OAuth2Session.return_value.authorization_url.return_value = \
|
||||
oauth_url, state
|
||||
|
||||
# login
|
||||
self.app.set_user(self.member)
|
||||
|
||||
# click activate on the service page
|
||||
response = self.app.get(reverse('discord:activate'))
|
||||
|
||||
# check we got a redirect to Discord OAuth
|
||||
self.assertRedirects(
|
||||
response, expected_url=oauth_url, fetch_redirect_response=False
|
||||
)
|
||||
|
||||
# simulate Discord callback
|
||||
response = self.app.get(
|
||||
reverse('discord:callback'), params={'code': authentication_code}
|
||||
)
|
||||
|
||||
# user was added to Discord
|
||||
self.assertTrue(mock_add_user.called)
|
||||
|
||||
# user got a success message
|
||||
self.assertTrue(mock_messages.success.called)
|
||||
@@ -1,244 +1,356 @@
|
||||
import json
|
||||
from unittest.mock import patch, Mock
|
||||
import urllib
|
||||
import datetime
|
||||
import requests_mock
|
||||
from unittest import mock
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from django.contrib.auth.models import Group, User
|
||||
from django.test import TestCase
|
||||
from django.conf import settings
|
||||
|
||||
from ..manager import DiscordOAuthManager
|
||||
from .. import manager
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
|
||||
from . import DEFAULT_AUTH_GROUP, add_permissions, MODULE_PATH
|
||||
from . import (
|
||||
TEST_GUILD_ID,
|
||||
TEST_USER_NAME,
|
||||
TEST_USER_ID,
|
||||
TEST_MAIN_NAME,
|
||||
TEST_MAIN_ID,
|
||||
MODULE_PATH
|
||||
)
|
||||
from ..app_settings import (
|
||||
DISCORD_APP_ID,
|
||||
DISCORD_APP_SECRET,
|
||||
DISCORD_CALLBACK_URL,
|
||||
)
|
||||
from ..discord_client import DiscordClient, DiscordApiBackoff
|
||||
from ..models import DiscordUser
|
||||
from ..utils import set_logger_to_file
|
||||
|
||||
|
||||
class DiscordManagerTestCase(TestCase):
|
||||
logger = set_logger_to_file(MODULE_PATH + '.managers', __file__)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DISCORD_GUILD_ID', TEST_GUILD_ID)
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
@patch(MODULE_PATH + '.models.DiscordUser.objects._exchange_auth_code_for_token')
|
||||
@patch(MODULE_PATH + '.models.DiscordUser.objects.model._guild_get_or_create_role_ids')
|
||||
@patch(MODULE_PATH + '.models.DiscordUser.objects.user_group_names')
|
||||
@patch(MODULE_PATH + '.models.DiscordUser.objects.user_formatted_nick')
|
||||
class TestAddUser(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
self.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
self.user_info = {
|
||||
'id': TEST_USER_ID,
|
||||
'name': TEST_USER_NAME,
|
||||
'username': TEST_USER_NAME,
|
||||
'discriminator': '1234',
|
||||
}
|
||||
self.access_token = 'accesstoken'
|
||||
|
||||
def test_can_create_user_no_roles_no_nick(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_guild_get_or_create_role_ids,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
):
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = []
|
||||
mock_guild_get_or_create_role_ids.return_value = None
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = True
|
||||
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
args, kwargs = mock_DiscordClient.return_value.add_guild_member.call_args
|
||||
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
|
||||
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
|
||||
self.assertEqual(kwargs['access_token'], self.access_token)
|
||||
self.assertIsNone(kwargs['role_ids'])
|
||||
self.assertIsNone(kwargs['nick'])
|
||||
|
||||
def test__sanitize_group_name(self):
|
||||
test_group_name = str(10**103)
|
||||
group_name = DiscordOAuthManager._sanitize_group_name(test_group_name)
|
||||
def test_can_create_user_with_roles_no_nick(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_guild_get_or_create_role_ids,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
):
|
||||
role_ids = [1, 2, 3]
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = ['a', 'b', 'c']
|
||||
mock_guild_get_or_create_role_ids.return_value = role_ids
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = True
|
||||
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
args, kwargs = mock_DiscordClient.return_value.add_guild_member.call_args
|
||||
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
|
||||
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
|
||||
self.assertEqual(kwargs['access_token'], self.access_token)
|
||||
self.assertEqual(kwargs['role_ids'], role_ids)
|
||||
self.assertIsNone(kwargs['nick'])
|
||||
|
||||
self.assertEqual(group_name, test_group_name[:100])
|
||||
@patch(MODULE_PATH + '.managers.DISCORD_SYNC_NAMES', True)
|
||||
def test_can_create_user_no_roles_with_nick(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_guild_get_or_create_role_ids,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
):
|
||||
mock_user_formatted_nick.return_value = TEST_MAIN_NAME
|
||||
mock_user_group_names.return_value = []
|
||||
mock_guild_get_or_create_role_ids.return_value = []
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = True
|
||||
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
args, kwargs = mock_DiscordClient.return_value.add_guild_member.call_args
|
||||
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
|
||||
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
|
||||
self.assertEqual(kwargs['access_token'], self.access_token)
|
||||
self.assertIsNone(kwargs['role_ids'])
|
||||
self.assertEqual(kwargs['nick'], TEST_MAIN_NAME)
|
||||
|
||||
def test_generate_Bot_add_url(self):
|
||||
bot_add_url = DiscordOAuthManager.generate_bot_add_url()
|
||||
@patch(MODULE_PATH + '.managers.DISCORD_SYNC_NAMES', False)
|
||||
def test_can_create_user_no_roles_and_without_nick_if_turned_off(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_guild_get_or_create_role_ids,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
):
|
||||
mock_user_formatted_nick.return_value = TEST_MAIN_NAME
|
||||
mock_user_group_names.return_value = []
|
||||
mock_guild_get_or_create_role_ids.return_value = []
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = True
|
||||
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
args, kwargs = mock_DiscordClient.return_value.add_guild_member.call_args
|
||||
self.assertEqual(kwargs['guild_id'], TEST_GUILD_ID)
|
||||
self.assertEqual(kwargs['user_id'], TEST_USER_ID)
|
||||
self.assertEqual(kwargs['access_token'], self.access_token)
|
||||
self.assertIsNone(kwargs['role_ids'])
|
||||
self.assertIsNone(kwargs['nick'])
|
||||
|
||||
def test_can_activate_existing_guild_member(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_guild_get_or_create_role_ids,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
):
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = []
|
||||
mock_guild_get_or_create_role_ids.return_value = None
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = None
|
||||
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
|
||||
def test_return_false_when_user_creation_fails(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_guild_get_or_create_role_ids,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
):
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = []
|
||||
mock_guild_get_or_create_role_ids.return_value = None
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.add_guild_member.return_value = False
|
||||
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
self.assertFalse(result)
|
||||
self.assertFalse(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
|
||||
auth_url = manager.AUTH_URL
|
||||
real_bot_add_url = '{}?client_id=appid&scope=bot&permissions={}'.format(auth_url, manager.BOT_PERMISSIONS)
|
||||
def test_return_false_when_on_api_backoff(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_guild_get_or_create_role_ids,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
):
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = []
|
||||
mock_guild_get_or_create_role_ids.return_value = None
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_DiscordClient.return_value.add_guild_member.side_effect = \
|
||||
DiscordApiBackoff(999)
|
||||
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
self.assertFalse(result)
|
||||
self.assertFalse(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
|
||||
def test_return_false_on_http_error(
|
||||
self,
|
||||
mock_user_formatted_nick,
|
||||
mock_user_group_names,
|
||||
mock_guild_get_or_create_role_ids,
|
||||
mock_exchange_auth_code_for_token,
|
||||
mock_DiscordClient
|
||||
):
|
||||
mock_user_formatted_nick.return_value = None
|
||||
mock_user_group_names.return_value = []
|
||||
mock_guild_get_or_create_role_ids.return_value = None
|
||||
mock_exchange_auth_code_for_token.return_value = self.access_token
|
||||
mock_DiscordClient.return_value.current_user.return_value = self.user_info
|
||||
mock_exception = HTTPError('error')
|
||||
mock_exception.response = Mock()
|
||||
mock_exception.response.status_code = 500
|
||||
mock_DiscordClient.return_value.add_guild_member.side_effect = mock_exception
|
||||
|
||||
result = DiscordUser.objects.add_user(self.user, authorization_code='abcdef')
|
||||
self.assertFalse(result)
|
||||
self.assertFalse(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.add_guild_member.called)
|
||||
|
||||
|
||||
class TestOauthHelpers(TestCase):
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DISCORD_APP_ID', '123456')
|
||||
def test_generate_bot_add_url(self):
|
||||
bot_add_url = DiscordUser.objects.generate_bot_add_url()
|
||||
|
||||
auth_url = DiscordClient.OAUTH_BASE_URL
|
||||
real_bot_add_url = (
|
||||
f'{auth_url}?client_id=123456&scope=bot'
|
||||
f'&permissions={DiscordUser.objects.BOT_PERMISSIONS}'
|
||||
)
|
||||
self.assertEqual(bot_add_url, real_bot_add_url)
|
||||
|
||||
def test_generate_oauth_redirect_url(self):
|
||||
oauth_url = DiscordOAuthManager.generate_oauth_redirect_url()
|
||||
oauth_url = DiscordUser.objects.generate_oauth_redirect_url()
|
||||
|
||||
self.assertIn(manager.AUTH_URL, oauth_url)
|
||||
self.assertIn('+'.join(manager.SCOPES), oauth_url)
|
||||
self.assertIn(settings.DISCORD_APP_ID, oauth_url)
|
||||
self.assertIn(urllib.parse.quote_plus(settings.DISCORD_CALLBACK_URL), oauth_url)
|
||||
self.assertIn(DiscordClient.OAUTH_BASE_URL, oauth_url)
|
||||
self.assertIn('+'.join(DiscordUser.objects.SCOPES), oauth_url)
|
||||
self.assertIn(DISCORD_APP_ID, oauth_url)
|
||||
self.assertIn(urllib.parse.quote_plus(DISCORD_CALLBACK_URL), oauth_url)
|
||||
|
||||
@mock.patch(MODULE_PATH + '.manager.OAuth2Session')
|
||||
def test__process_callback_code(self, oauth):
|
||||
@patch(MODULE_PATH + '.managers.OAuth2Session')
|
||||
def test_process_callback_code(self, oauth):
|
||||
instance = oauth.return_value
|
||||
instance.fetch_token.return_value = {'access_token': 'mywonderfultoken'}
|
||||
|
||||
token = DiscordOAuthManager._process_callback_code('12345')
|
||||
token = DiscordUser.objects._exchange_auth_code_for_token('12345')
|
||||
|
||||
self.assertTrue(oauth.called)
|
||||
args, kwargs = oauth.call_args
|
||||
self.assertEqual(args[0], settings.DISCORD_APP_ID)
|
||||
self.assertEqual(kwargs['redirect_uri'], settings.DISCORD_CALLBACK_URL)
|
||||
self.assertEqual(args[0], DISCORD_APP_ID)
|
||||
self.assertEqual(kwargs['redirect_uri'], DISCORD_CALLBACK_URL)
|
||||
self.assertTrue(instance.fetch_token.called)
|
||||
args, kwargs = instance.fetch_token.call_args
|
||||
self.assertEqual(args[0], manager.TOKEN_URL)
|
||||
self.assertEqual(kwargs['client_secret'], settings.DISCORD_APP_SECRET)
|
||||
self.assertEqual(args[0], DiscordClient.OAUTH_TOKEN_URL)
|
||||
self.assertEqual(kwargs['client_secret'], DISCORD_APP_SECRET)
|
||||
self.assertEqual(kwargs['code'], '12345')
|
||||
self.assertEqual(token['access_token'], 'mywonderfultoken')
|
||||
self.assertEqual(token, 'mywonderfultoken')
|
||||
|
||||
@mock.patch(MODULE_PATH + '.manager.DiscordOAuthManager._process_callback_code')
|
||||
@requests_mock.Mocker()
|
||||
def test_add_user(self, oauth_token, m):
|
||||
# Arrange
|
||||
oauth_token.return_value = {'access_token': 'accesstoken'}
|
||||
|
||||
headers = {'accept': 'application/json', 'authorization': 'Bearer accesstoken'}
|
||||
class TestUserFormattedNick(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
|
||||
def test_return_nick_when_user_has_main(self):
|
||||
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
|
||||
result = DiscordUser.objects.user_formatted_nick(self.user)
|
||||
expected = TEST_MAIN_NAME
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
m.register_uri('GET',
|
||||
manager.DISCORD_URL + "/users/@me",
|
||||
request_headers=headers,
|
||||
text=json.dumps({'id': "123456"}))
|
||||
def test_return_none_if_user_has_no_main(self):
|
||||
result = DiscordUser.objects.user_formatted_nick(self.user)
|
||||
self.assertIsNone(result)
|
||||
|
||||
headers = {'accept': 'application/json', 'authorization': 'Bot ' + settings.DISCORD_BOT_TOKEN}
|
||||
|
||||
m.register_uri('PUT',
|
||||
manager.DISCORD_URL + '/guilds/' + str(settings.DISCORD_GUILD_ID) + '/members/123456',
|
||||
request_headers=headers,
|
||||
text='{}')
|
||||
class TestUserGroupNames(TestCase):
|
||||
|
||||
# Act
|
||||
return_value = DiscordOAuthManager.add_user('abcdef', [])
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.group_1 = Group.objects.create(name='Group 1')
|
||||
cls.group_2 = Group.objects.create(name='Group 2')
|
||||
|
||||
def setUp(self):
|
||||
self.user = AuthUtils.create_member(TEST_USER_NAME)
|
||||
|
||||
# Assert
|
||||
self.assertEqual(return_value, '123456')
|
||||
self.assertEqual(m.call_count, 2)
|
||||
def test_return_groups_and_state_names_for_user(self):
|
||||
self.user.groups.add(self.group_1)
|
||||
result = DiscordUser.objects.user_group_names(self.user)
|
||||
expected = ['Group 1', 'Member']
|
||||
self.assertSetEqual(set(result), set(expected))
|
||||
|
||||
def test_return_state_only_if_user_has_no_groups(self):
|
||||
result = DiscordUser.objects.user_group_names(self.user)
|
||||
expected = ['Member']
|
||||
self.assertSetEqual(set(result), set(expected))
|
||||
|
||||
@requests_mock.Mocker()
|
||||
def test_delete_user(self, m):
|
||||
# Arrange
|
||||
headers = {'accept': 'application/json', 'authorization': 'Bot ' + settings.DISCORD_BOT_TOKEN}
|
||||
|
||||
user_id = 12345
|
||||
request_url = '{}/guilds/{}/members/{}'.format(manager.DISCORD_URL, settings.DISCORD_GUILD_ID, user_id)
|
||||
m.register_uri('DELETE',
|
||||
request_url,
|
||||
request_headers=headers,
|
||||
text=json.dumps({}))
|
||||
class TestUserHasAccount(TestCase):
|
||||
|
||||
# Act
|
||||
result = DiscordOAuthManager.delete_user(user_id)
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
|
||||
# Assert
|
||||
self.assertTrue(result)
|
||||
def test_return_true_if_user_has_account(self):
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
self.assertTrue(DiscordUser.objects.user_has_account(self.user))
|
||||
|
||||
###
|
||||
# Test 404 (already deleted)
|
||||
# Arrange
|
||||
m.register_uri('DELETE',
|
||||
request_url,
|
||||
request_headers=headers,
|
||||
status_code=404)
|
||||
def test_return_false_if_user_has_no_account(self):
|
||||
self.assertFalse(DiscordUser.objects.user_has_account(self.user))
|
||||
|
||||
# Act
|
||||
result = DiscordOAuthManager.delete_user(user_id)
|
||||
def test_return_false_if_user_does_not_exist(self):
|
||||
my_user = User(username='Dummy')
|
||||
self.assertFalse(DiscordUser.objects.user_has_account(my_user))
|
||||
|
||||
# Assert
|
||||
self.assertTrue(result)
|
||||
|
||||
###
|
||||
# Test 500 (some random API error)
|
||||
# Arrange
|
||||
m.register_uri('DELETE',
|
||||
request_url,
|
||||
request_headers=headers,
|
||||
status_code=500)
|
||||
|
||||
# Act
|
||||
result = DiscordOAuthManager.delete_user(user_id)
|
||||
|
||||
# Assert
|
||||
self.assertFalse(result)
|
||||
|
||||
@requests_mock.Mocker()
|
||||
def test_update_nickname(self, m):
|
||||
# Arrange
|
||||
headers = {'content-type': 'application/json', 'authorization': 'Bot ' + settings.DISCORD_BOT_TOKEN}
|
||||
|
||||
user_id = 12345
|
||||
request_url = '{}/guilds/{}/members/{}'.format(manager.DISCORD_URL, settings.DISCORD_GUILD_ID, user_id)
|
||||
m.patch(request_url,
|
||||
request_headers=headers)
|
||||
|
||||
# Act
|
||||
result = DiscordOAuthManager.update_nickname(user_id, 'somenick')
|
||||
|
||||
# Assert
|
||||
self.assertTrue(result)
|
||||
|
||||
@mock.patch(MODULE_PATH + '.manager.DiscordOAuthManager._get_user_roles')
|
||||
@mock.patch(MODULE_PATH + '.manager.DiscordOAuthManager._get_groups')
|
||||
@requests_mock.Mocker()
|
||||
def test_update_groups(self, group_cache, user_roles, m):
|
||||
# Arrange
|
||||
groups = ['Member', 'Blue', 'SpecialGroup']
|
||||
|
||||
group_cache.return_value = [{'id': '111', 'name': 'Member'},
|
||||
{'id': '222', 'name': 'Blue'},
|
||||
{'id': '333', 'name': 'SpecialGroup'},
|
||||
{'id': '444', 'name': 'NotYourGroup'}]
|
||||
user_roles.return_value = ['444']
|
||||
|
||||
headers = {'content-type': 'application/json', 'authorization': 'Bot ' + settings.DISCORD_BOT_TOKEN}
|
||||
user_id = 12345
|
||||
user_request_url = '{}/guilds/{}/members/{}'.format(manager.DISCORD_URL, settings.DISCORD_GUILD_ID, user_id)
|
||||
group_request_urls = ['{}/guilds/{}/members/{}/roles/{}'.format(manager.DISCORD_URL, settings.DISCORD_GUILD_ID, user_id, g['id']) for g in group_cache.return_value]
|
||||
|
||||
m.patch(user_request_url, request_headers=headers)
|
||||
[m.put(url, request_headers=headers) for url in group_request_urls[:-1]]
|
||||
m.delete(group_request_urls[-1], request_headers=headers)
|
||||
|
||||
# Act
|
||||
DiscordOAuthManager.update_groups(user_id, groups)
|
||||
|
||||
# Assert
|
||||
self.assertEqual(len(m.request_history), 4, 'Must be 4 HTTP calls made')
|
||||
|
||||
@mock.patch(MODULE_PATH + '.manager.cache')
|
||||
@mock.patch(MODULE_PATH + '.manager.DiscordOAuthManager._get_user_roles')
|
||||
@mock.patch(MODULE_PATH + '.manager.DiscordOAuthManager._group_name_to_id')
|
||||
@requests_mock.Mocker()
|
||||
def test_update_groups_backoff(self, name_to_id, user_groups, djcache, m):
|
||||
# Arrange
|
||||
groups = ['Member']
|
||||
user_groups.return_value = []
|
||||
name_to_id.return_value = '111'
|
||||
|
||||
headers = {'content-type': 'application/json', 'authorization': 'Bot ' + settings.DISCORD_BOT_TOKEN}
|
||||
user_id = 12345
|
||||
request_url = '{}/guilds/{}/members/{}/roles/{}'.format(manager.DISCORD_URL, settings.DISCORD_GUILD_ID, user_id, name_to_id.return_value)
|
||||
|
||||
djcache.get.return_value = None # No existing backoffs in cache
|
||||
|
||||
m.put(request_url,
|
||||
request_headers=headers,
|
||||
headers={'Retry-After': '200000'},
|
||||
status_code=429)
|
||||
|
||||
# Act & Assert
|
||||
with self.assertRaises(manager.DiscordApiBackoff) as bo:
|
||||
try:
|
||||
DiscordOAuthManager.update_groups(user_id, groups, blocking=False)
|
||||
except manager.DiscordApiBackoff as bo:
|
||||
self.assertEqual(bo.retry_after, 200000, 'Retry-After time must be equal to Retry-After set in header')
|
||||
self.assertFalse(bo.global_ratelimit, 'global_ratelimit must be False')
|
||||
raise bo
|
||||
|
||||
self.assertTrue(djcache.set.called)
|
||||
args, kwargs = djcache.set.call_args
|
||||
self.assertEqual(args[0], 'DISCORD_BACKOFF_update_groups')
|
||||
self.assertTrue(datetime.datetime.strptime(args[1], manager.cache_time_format) > datetime.datetime.now())
|
||||
|
||||
@mock.patch(MODULE_PATH + '.manager.cache')
|
||||
@mock.patch(MODULE_PATH + '.manager.DiscordOAuthManager._get_user_roles')
|
||||
@mock.patch(MODULE_PATH + '.manager.DiscordOAuthManager._group_name_to_id')
|
||||
@requests_mock.Mocker()
|
||||
def test_update_groups_global_backoff(self, name_to_id, user_groups, djcache, m):
|
||||
# Arrange
|
||||
groups = ['Member']
|
||||
user_groups.return_value = []
|
||||
name_to_id.return_value = '111'
|
||||
|
||||
headers = {'content-type': 'application/json', 'authorization': 'Bot ' + settings.DISCORD_BOT_TOKEN}
|
||||
user_id = 12345
|
||||
request_url = '{}/guilds/{}/members/{}/roles/{}'.format(manager.DISCORD_URL, settings.DISCORD_GUILD_ID, user_id, name_to_id.return_value)
|
||||
|
||||
djcache.get.return_value = None # No existing backoffs in cache
|
||||
|
||||
m.put(request_url,
|
||||
request_headers=headers,
|
||||
headers={'Retry-After': '200000', 'X-RateLimit-Global': 'true'},
|
||||
status_code=429)
|
||||
|
||||
# Act & Assert
|
||||
with self.assertRaises(manager.DiscordApiBackoff) as bo:
|
||||
try:
|
||||
DiscordOAuthManager.update_groups(user_id, groups, blocking=False)
|
||||
except manager.DiscordApiBackoff as bo:
|
||||
self.assertEqual(bo.retry_after, 200000, 'Retry-After time must be equal to Retry-After set in header')
|
||||
self.assertTrue(bo.global_ratelimit, 'global_ratelimit must be True')
|
||||
raise bo
|
||||
|
||||
self.assertTrue(djcache.set.called)
|
||||
args, kwargs = djcache.set.call_args
|
||||
self.assertEqual(args[0], 'DISCORD_BACKOFF_GLOBAL')
|
||||
self.assertTrue(datetime.datetime.strptime(args[1], manager.cache_time_format) > datetime.datetime.now())
|
||||
def test_return_false_if_not_called_with_user_object(self):
|
||||
self.assertFalse(DiscordUser.objects.user_has_account('abc'))
|
||||
|
||||
222
allianceauth/services/modules/discord/tests/test_models.py
Normal file
222
allianceauth/services/modules/discord/tests/test_models.py
Normal file
@@ -0,0 +1,222 @@
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
|
||||
from . import TEST_USER_NAME, TEST_USER_ID, TEST_MAIN_NAME, TEST_MAIN_ID, MODULE_PATH
|
||||
from ..discord_client import DiscordClient, DiscordApiBackoff
|
||||
from ..models import DiscordUser
|
||||
from ..utils import set_logger_to_file
|
||||
|
||||
|
||||
logger = set_logger_to_file(MODULE_PATH + '.models', __file__)
|
||||
|
||||
|
||||
class TestBasicsAndHelpers(TestCase):
|
||||
|
||||
def test_str(self):
|
||||
user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
discord_user = DiscordUser.objects.create(user=user, uid=TEST_USER_ID)
|
||||
expected = 'Peter Parker - 198765432012345678'
|
||||
self.assertEqual(str(discord_user), expected)
|
||||
|
||||
def test_repr(self):
|
||||
user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
discord_user = DiscordUser.objects.create(user=user, uid=TEST_USER_ID)
|
||||
expected = 'DiscordUser(user=\'Peter Parker\', uid=198765432012345678)'
|
||||
self.assertEqual(repr(discord_user), expected)
|
||||
|
||||
def test_guild_get_or_create_role_ids(self):
|
||||
mock_client = Mock(spec=DiscordClient)
|
||||
mock_client.match_guild_roles_to_names.return_value = \
|
||||
[({'id': 1, 'name': 'alpha'}, True), ({'id': 2, 'name': 'bravo'}, True)]
|
||||
|
||||
result = DiscordUser._guild_get_or_create_role_ids(mock_client, [])
|
||||
excepted = [1, 2]
|
||||
self.assertEqual(set(result), set(excepted))
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
class TestUpdateNick(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
self.discord_user = DiscordUser.objects.create(
|
||||
user=self.user, uid=TEST_USER_ID
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def user_info(nick):
|
||||
return {
|
||||
'user': {
|
||||
'id': TEST_USER_ID,
|
||||
'username': TEST_USER_NAME
|
||||
},
|
||||
'nick': nick,
|
||||
'roles': [1, 2, 3]
|
||||
}
|
||||
|
||||
def test_can_update(self, mock_DiscordClient):
|
||||
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = True
|
||||
|
||||
result = self.discord_user.update_nickname()
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
|
||||
def test_dont_update_if_user_has_no_main(self, mock_DiscordClient):
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = False
|
||||
|
||||
result = self.discord_user.update_nickname()
|
||||
self.assertFalse(result)
|
||||
self.assertFalse(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
|
||||
def test_return_none_if_user_no_longer_a_member(
|
||||
self, mock_DiscordClient
|
||||
):
|
||||
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = None
|
||||
|
||||
result = self.discord_user.update_nickname()
|
||||
self.assertIsNone(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
|
||||
def test_return_false_if_api_returns_false(self, mock_DiscordClient):
|
||||
AuthUtils.add_main_character_2(self.user, TEST_MAIN_NAME, TEST_MAIN_ID)
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = False
|
||||
|
||||
result = self.discord_user.update_nickname()
|
||||
self.assertFalse(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.models.notify')
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
class TestDeleteUser(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
self.discord_user = DiscordUser.objects.create(
|
||||
user=self.user, uid=TEST_USER_ID
|
||||
)
|
||||
|
||||
def test_can_delete_user(self, mock_DiscordClient, mock_notify):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
result = self.discord_user.delete_user()
|
||||
self.assertTrue(result)
|
||||
self.assertFalse(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.remove_guild_member.called)
|
||||
self.assertFalse(mock_notify.called)
|
||||
|
||||
def test_can_delete_user_and_notify_user(self, mock_DiscordClient, mock_notify):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
result = self.discord_user.delete_user(notify_user=True)
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(mock_notify.called)
|
||||
|
||||
def test_can_delete_user_when_member_is_unknown(
|
||||
self, mock_DiscordClient, mock_notify
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = None
|
||||
result = self.discord_user.delete_user()
|
||||
self.assertTrue(result)
|
||||
self.assertFalse(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.remove_guild_member.called)
|
||||
self.assertFalse(mock_notify.called)
|
||||
|
||||
def test_return_false_when_api_fails(self, mock_DiscordClient, mock_notify):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = False
|
||||
result = self.discord_user.delete_user()
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_dont_notify_if_user_was_already_deleted_and_return_none(
|
||||
self, mock_DiscordClient, mock_notify
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = None
|
||||
DiscordUser.objects.get(pk=self.discord_user.pk).delete()
|
||||
result = self.discord_user.delete_user()
|
||||
self.assertIsNone(result)
|
||||
self.assertFalse(
|
||||
DiscordUser.objects.filter(user=self.user, uid=TEST_USER_ID).exists()
|
||||
)
|
||||
self.assertTrue(mock_DiscordClient.return_value.remove_guild_member.called)
|
||||
self.assertFalse(mock_notify.called)
|
||||
|
||||
def test_return_false_on_api_backoff(self, mock_DiscordClient, mock_notify):
|
||||
mock_DiscordClient.return_value.remove_guild_member.side_effect = \
|
||||
DiscordApiBackoff(999)
|
||||
result = self.discord_user.delete_user()
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_return_false_on_http_error(self, mock_DiscordClient, mock_notify):
|
||||
mock_exception = HTTPError('error')
|
||||
mock_exception.response = Mock()
|
||||
mock_exception.response.status_code = 500
|
||||
mock_DiscordClient.return_value.remove_guild_member.side_effect = \
|
||||
mock_exception
|
||||
result = self.discord_user.delete_user()
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
@patch(MODULE_PATH + '.models.DiscordUser._guild_get_or_create_role_ids')
|
||||
@patch(MODULE_PATH + '.models.DiscordUser.objects.user_group_names')
|
||||
class TestUpdateGroups(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.user = AuthUtils.create_user(TEST_USER_NAME)
|
||||
self.discord_user = DiscordUser.objects.create(
|
||||
user=self.user, uid=TEST_USER_ID
|
||||
)
|
||||
|
||||
def test_can_update(
|
||||
self,
|
||||
mock_user_group_names,
|
||||
mock_guild_get_or_create_role_ids,
|
||||
mock_DiscordClient
|
||||
):
|
||||
roles_requested = [1, 2, 3]
|
||||
mock_user_group_names.return_value = []
|
||||
mock_guild_get_or_create_role_ids.return_value = roles_requested
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = True
|
||||
|
||||
result = self.discord_user.update_groups()
|
||||
self.assertTrue(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
|
||||
def test_return_none_if_user_no_longer_a_member(
|
||||
self,
|
||||
mock_user_group_names,
|
||||
mock_guild_get_or_create_role_ids,
|
||||
mock_DiscordClient
|
||||
):
|
||||
roles_requested = [1, 2, 3]
|
||||
mock_user_group_names.return_value = []
|
||||
mock_guild_get_or_create_role_ids.return_value = roles_requested
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = None
|
||||
|
||||
result = self.discord_user.update_groups()
|
||||
self.assertIsNone(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
|
||||
def test_return_false_if_api_returns_false(
|
||||
self,
|
||||
mock_user_group_names,
|
||||
mock_guild_get_or_create_role_ids,
|
||||
mock_DiscordClient
|
||||
):
|
||||
roles_requested = [1, 2, 3]
|
||||
mock_user_group_names.return_value = []
|
||||
mock_guild_get_or_create_role_ids.return_value = roles_requested
|
||||
mock_DiscordClient.return_value.modify_guild_member.return_value = False
|
||||
|
||||
result = self.discord_user.update_groups()
|
||||
self.assertFalse(result)
|
||||
self.assertTrue(mock_DiscordClient.return_value.modify_guild_member.called)
|
||||
310
allianceauth/services/modules/discord/tests/test_tasks.py
Normal file
310
allianceauth/services/modules/discord/tests/test_tasks.py
Normal file
@@ -0,0 +1,310 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from celery.exceptions import Retry
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from django.test import TestCase
|
||||
from django.contrib.auth.models import Group
|
||||
from django.test.utils import override_settings
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
|
||||
from . import TEST_USER_NAME, TEST_USER_ID, TEST_MAIN_NAME, TEST_MAIN_ID
|
||||
from ..models import DiscordUser
|
||||
from ..discord_client import DiscordApiBackoff
|
||||
from .. import tasks
|
||||
from ..utils import set_logger_to_file
|
||||
|
||||
|
||||
MODULE_PATH = 'allianceauth.services.modules.discord.tasks'
|
||||
logger = set_logger_to_file(MODULE_PATH, __file__)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.DiscordUser.update_groups')
|
||||
class TestUpdateGroups(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.user = AuthUtils.create_member(TEST_USER_NAME)
|
||||
cls.group_1 = Group.objects.create(name='Group 1')
|
||||
cls.group_2 = Group.objects.create(name='Group 2')
|
||||
cls.group_1.user_set.add(cls.user)
|
||||
cls.group_2.user_set.add(cls.user)
|
||||
|
||||
def test_can_update_groups(self, mock_update_groups):
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
tasks.update_groups(self.user.pk)
|
||||
self.assertTrue(mock_update_groups.called)
|
||||
|
||||
def test_no_action_if_user_has_no_discord_account(self, mock_update_groups):
|
||||
tasks.update_groups(self.user.pk)
|
||||
self.assertFalse(mock_update_groups.called)
|
||||
|
||||
def test_retries_on_api_backoff(self, mock_update_groups):
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
mock_exception = DiscordApiBackoff(999)
|
||||
mock_update_groups.side_effect = mock_exception
|
||||
|
||||
with self.assertRaises(Retry):
|
||||
tasks.update_groups(self.user.pk)
|
||||
|
||||
def test_retry_on_http_error_except_404(self, mock_update_groups):
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
mock_exception = HTTPError('error')
|
||||
mock_exception.response = MagicMock()
|
||||
mock_exception.response.status_code = 500
|
||||
mock_update_groups.side_effect = mock_exception
|
||||
|
||||
with self.assertRaises(Retry):
|
||||
tasks.update_groups(self.user.pk)
|
||||
|
||||
def test_retry_on_http_error_404_when_user_not_deleted(self, mock_update_groups):
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
mock_exception = HTTPError('error')
|
||||
mock_exception.response = MagicMock()
|
||||
mock_exception.response.status_code = 404
|
||||
mock_update_groups.side_effect = mock_exception
|
||||
|
||||
with self.assertRaises(Retry):
|
||||
tasks.update_groups(self.user.pk)
|
||||
|
||||
def test_retry_on_non_http_error(self, mock_update_groups):
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
mock_update_groups.side_effect = ConnectionError
|
||||
|
||||
with self.assertRaises(Retry):
|
||||
tasks.update_groups(self.user.pk)
|
||||
|
||||
@patch(MODULE_PATH + '.DISCORD_TASKS_MAX_RETRIES', 3)
|
||||
def test_log_error_if_retries_exhausted(self, mock_update_groups):
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
mock_task = MagicMock(**{'request.retries': 3})
|
||||
mock_update_groups.side_effect = ConnectionError
|
||||
update_groups_inner = tasks.update_groups.__wrapped__.__func__
|
||||
|
||||
update_groups_inner(mock_task, self.user.pk)
|
||||
|
||||
@patch(MODULE_PATH + '.delete_user.delay')
|
||||
def test_delete_user_if_user_is_no_longer_member_of_discord_server(
|
||||
self, mock_delete_user, mock_update_groups
|
||||
):
|
||||
mock_update_groups.return_value = None
|
||||
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
tasks.update_groups(self.user.pk)
|
||||
self.assertTrue(mock_update_groups.called)
|
||||
self.assertTrue(mock_delete_user.called)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.DiscordUser.update_nickname')
|
||||
class TestUpdateNickname(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.user = AuthUtils.create_member(TEST_USER_NAME)
|
||||
AuthUtils.add_main_character_2(
|
||||
cls.user,
|
||||
TEST_MAIN_NAME,
|
||||
TEST_MAIN_ID,
|
||||
corp_id='2',
|
||||
corp_name='test_corp',
|
||||
corp_ticker='TEST',
|
||||
disconnect_signals=True
|
||||
)
|
||||
cls.discord_user = DiscordUser.objects.create(user=cls.user, uid=TEST_USER_ID)
|
||||
|
||||
def test_can_update_nickname(self, mock_update_nickname):
|
||||
mock_update_nickname.return_value = True
|
||||
|
||||
tasks.update_nickname(self.user.pk)
|
||||
self.assertTrue(mock_update_nickname.called)
|
||||
|
||||
def test_no_action_when_user_had_no_account(self, mock_update_nickname):
|
||||
my_user = AuthUtils.create_user('Dummy User')
|
||||
mock_update_nickname.return_value = False
|
||||
|
||||
tasks.update_nickname(my_user.pk)
|
||||
self.assertFalse(mock_update_nickname.called)
|
||||
|
||||
def test_retries_on_api_backoff(self, mock_update_nickname):
|
||||
mock_exception = DiscordApiBackoff(999)
|
||||
mock_update_nickname.side_effect = mock_exception
|
||||
|
||||
with self.assertRaises(Retry):
|
||||
tasks.update_nickname(self.user.pk)
|
||||
|
||||
def test_retries_on_general_exception(self, mock_update_nickname):
|
||||
mock_update_nickname.side_effect = ConnectionError
|
||||
|
||||
with self.assertRaises(Retry):
|
||||
tasks.update_nickname(self.user.pk)
|
||||
|
||||
@patch(MODULE_PATH + '.DISCORD_TASKS_MAX_RETRIES', 3)
|
||||
def test_log_error_if_retries_exhausted(self, mock_update_nickname):
|
||||
mock_task = MagicMock(**{'request.retries': 3})
|
||||
mock_update_nickname.side_effect = ConnectionError
|
||||
update_nickname_inner = tasks.update_nickname.__wrapped__.__func__
|
||||
|
||||
update_nickname_inner(mock_task, self.user.pk)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.DiscordUser.delete_user')
|
||||
class TestDeleteUser(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.user = AuthUtils.create_member('Peter Parker')
|
||||
cls.discord_user = DiscordUser.objects.create(user=cls.user, uid=TEST_USER_ID)
|
||||
|
||||
def test_can_delete_user(self, mock_delete_user):
|
||||
mock_delete_user.return_value = True
|
||||
|
||||
tasks.delete_user(self.user.pk)
|
||||
self.assertTrue(mock_delete_user.called)
|
||||
|
||||
def test_can_delete_user_with_notify(self, mock_delete_user):
|
||||
mock_delete_user.return_value = True
|
||||
|
||||
tasks.delete_user(self.user.pk, notify_user=True)
|
||||
self.assertTrue(mock_delete_user.called)
|
||||
args, kwargs = mock_delete_user.call_args
|
||||
self.assertTrue(kwargs['notify_user'])
|
||||
|
||||
@patch(MODULE_PATH + '.delete_user.delay')
|
||||
def test_dont_retry_delete_user_if_user_is_no_longer_member_of_discord_server(
|
||||
self, mock_delete_user_delay, mock_delete_user
|
||||
):
|
||||
mock_delete_user.return_value = None
|
||||
|
||||
tasks.delete_user(self.user.pk)
|
||||
self.assertTrue(mock_delete_user.called)
|
||||
self.assertFalse(mock_delete_user_delay.called)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.DiscordUser.update_groups')
|
||||
class TestTaskPerformUserAction(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.user = AuthUtils.create_member('Peter Parker')
|
||||
cls.discord_user = DiscordUser.objects.create(user=cls.user, uid=TEST_USER_ID)
|
||||
|
||||
def test_raise_value_error_on_unknown_method(self, mock_update_groups):
|
||||
mock_task = MagicMock(**{'request.retries': 0})
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tasks._task_perform_user_action(mock_task, self.user.pk, 'invalid_method')
|
||||
|
||||
def test_catch_and_log_unexpected_exceptions(self, mock_update_groups):
|
||||
mock_task = MagicMock(**{'request.retries': 0})
|
||||
mock_update_groups.side_effect = RuntimeError
|
||||
|
||||
tasks._task_perform_user_action(mock_task, self.user.pk, 'update_groups')
|
||||
|
||||
|
||||
@override_settings(CELERY_ALWAYS_EAGER=True)
|
||||
class TestBulkTasks(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.user_1 = AuthUtils.create_user('Peter Parker')
|
||||
cls.user_2 = AuthUtils.create_user('Kara Danvers')
|
||||
cls.user_3 = AuthUtils.create_user('Clark Kent')
|
||||
DiscordUser.objects.all().delete()
|
||||
|
||||
@patch(MODULE_PATH + '.update_groups.si')
|
||||
def test_can_update_groups_for_multiple_users(self, mock_update_groups):
|
||||
du_1 = DiscordUser.objects.create(user=self.user_1, uid=123)
|
||||
du_2 = DiscordUser.objects.create(user=self.user_2, uid=456)
|
||||
DiscordUser.objects.create(user=self.user_3, uid=789)
|
||||
expected_pks = [du_1.pk, du_2.pk]
|
||||
|
||||
tasks.update_groups_bulk(expected_pks)
|
||||
self.assertEqual(mock_update_groups.call_count, 2)
|
||||
current_pks = [args[0][0] for args in mock_update_groups.call_args_list]
|
||||
|
||||
self.assertSetEqual(set(current_pks), set(expected_pks))
|
||||
|
||||
@patch(MODULE_PATH + '.update_groups.si')
|
||||
def test_can_update_all_groups(self, mock_update_groups):
|
||||
du_1 = DiscordUser.objects.create(user=self.user_1, uid=123)
|
||||
du_2 = DiscordUser.objects.create(user=self.user_2, uid=456)
|
||||
du_3 = DiscordUser.objects.create(user=self.user_3, uid=789)
|
||||
|
||||
tasks.update_all_groups()
|
||||
self.assertEqual(mock_update_groups.call_count, 3)
|
||||
current_pks = [args[0][0] for args in mock_update_groups.call_args_list]
|
||||
expected_pks = [du_1.pk, du_2.pk, du_3.pk]
|
||||
self.assertSetEqual(set(current_pks), set(expected_pks))
|
||||
|
||||
@patch(MODULE_PATH + '.update_nickname.si')
|
||||
def test_can_update_nicknames_for_multiple_users(self, mock_update_nickname):
|
||||
du_1 = DiscordUser.objects.create(user=self.user_1, uid=123)
|
||||
du_2 = DiscordUser.objects.create(user=self.user_2, uid=456)
|
||||
DiscordUser.objects.create(user=self.user_3, uid=789)
|
||||
expected_pks = [du_1.pk, du_2.pk]
|
||||
|
||||
tasks.update_nicknames_bulk(expected_pks)
|
||||
self.assertEqual(mock_update_nickname.call_count, 2)
|
||||
current_pks = [
|
||||
args[0][0] for args in mock_update_nickname.call_args_list
|
||||
]
|
||||
self.assertSetEqual(set(current_pks), set(expected_pks))
|
||||
|
||||
@patch(MODULE_PATH + '.update_nickname.si')
|
||||
def test_can_update_nicknames_for_all_users(self, mock_update_nickname):
|
||||
du_1 = DiscordUser.objects.create(user=self.user_1, uid='123')
|
||||
du_2 = DiscordUser.objects.create(user=self.user_2, uid='456')
|
||||
du_3 = DiscordUser.objects.create(user=self.user_3, uid='789')
|
||||
|
||||
tasks.update_all_nicknames()
|
||||
self.assertEqual(mock_update_nickname.call_count, 3)
|
||||
current_pks = [
|
||||
args[0][0] for args in mock_update_nickname.call_args_list
|
||||
]
|
||||
expected_pks = [du_1.pk, du_2.pk, du_3.pk]
|
||||
self.assertSetEqual(set(current_pks), set(expected_pks))
|
||||
|
||||
@patch(MODULE_PATH + '.DISCORD_SYNC_NAMES', True)
|
||||
@patch(MODULE_PATH + '.update_nickname')
|
||||
@patch(MODULE_PATH + '.update_groups')
|
||||
def test_can_update_all_incl_nicknames(
|
||||
self, mock_update_groups, mock_update_nickname
|
||||
):
|
||||
du_1 = DiscordUser.objects.create(user=self.user_1, uid=123)
|
||||
du_2 = DiscordUser.objects.create(user=self.user_2, uid=456)
|
||||
du_3 = DiscordUser.objects.create(user=self.user_3, uid=789)
|
||||
|
||||
tasks.update_all()
|
||||
self.assertEqual(mock_update_groups.si.call_count, 3)
|
||||
current_pks = [args[0][0] for args in mock_update_groups.si.call_args_list]
|
||||
expected_pks = [du_1.pk, du_2.pk, du_3.pk]
|
||||
self.assertSetEqual(set(current_pks), set(expected_pks))
|
||||
|
||||
self.assertEqual(mock_update_nickname.si.call_count, 3)
|
||||
current_pks = [args[0][0] for args in mock_update_nickname.si.call_args_list]
|
||||
expected_pks = [du_1.pk, du_2.pk, du_3.pk]
|
||||
self.assertSetEqual(set(current_pks), set(expected_pks))
|
||||
|
||||
@patch(MODULE_PATH + '.DISCORD_SYNC_NAMES', False)
|
||||
@patch(MODULE_PATH + '.update_nickname')
|
||||
@patch(MODULE_PATH + '.update_groups')
|
||||
def test_can_update_all_excl_nicknames(
|
||||
self, mock_update_groups, mock_update_nickname
|
||||
):
|
||||
du_1 = DiscordUser.objects.create(user=self.user_1, uid=123)
|
||||
du_2 = DiscordUser.objects.create(user=self.user_2, uid=456)
|
||||
du_3 = DiscordUser.objects.create(user=self.user_3, uid=789)
|
||||
|
||||
tasks.update_all()
|
||||
self.assertEqual(mock_update_groups.si.call_count, 3)
|
||||
current_pks = [args[0][0] for args in mock_update_groups.si.call_args_list]
|
||||
expected_pks = [du_1.pk, du_2.pk, du_3.pk]
|
||||
self.assertSetEqual(set(current_pks), set(expected_pks))
|
||||
|
||||
self.assertEqual(mock_update_nickname.si.call_count, 0)
|
||||
102
allianceauth/services/modules/discord/tests/test_utils.py
Normal file
102
allianceauth/services/modules/discord/tests/test_utils.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from unittest.mock import Mock, patch
|
||||
from django.test import TestCase
|
||||
|
||||
from ..utils import clean_setting
|
||||
|
||||
MODULE_PATH = 'allianceauth.services.modules.discord.utils'
|
||||
|
||||
|
||||
class TestCleanSetting(TestCase):
|
||||
|
||||
@patch(MODULE_PATH + '.settings')
|
||||
def test_default_if_not_set(self, mock_settings):
|
||||
mock_settings.TEST_SETTING_DUMMY = Mock(spec=None)
|
||||
result = clean_setting(
|
||||
'TEST_SETTING_DUMMY',
|
||||
False,
|
||||
)
|
||||
self.assertEqual(result, False)
|
||||
|
||||
@patch(MODULE_PATH + '.settings')
|
||||
def test_default_if_not_set_for_none(self, mock_settings):
|
||||
mock_settings.TEST_SETTING_DUMMY = Mock(spec=None)
|
||||
result = clean_setting(
|
||||
'TEST_SETTING_DUMMY',
|
||||
None,
|
||||
required_type=int
|
||||
)
|
||||
self.assertEqual(result, None)
|
||||
|
||||
@patch(MODULE_PATH + '.settings')
|
||||
def test_true_stays_true(self, mock_settings):
|
||||
mock_settings.TEST_SETTING_DUMMY = True
|
||||
result = clean_setting(
|
||||
'TEST_SETTING_DUMMY',
|
||||
False,
|
||||
)
|
||||
self.assertEqual(result, True)
|
||||
|
||||
@patch(MODULE_PATH + '.settings')
|
||||
def test_false_stays_false(self, mock_settings):
|
||||
mock_settings.TEST_SETTING_DUMMY = False
|
||||
result = clean_setting(
|
||||
'TEST_SETTING_DUMMY',
|
||||
False
|
||||
)
|
||||
self.assertEqual(result, False)
|
||||
|
||||
@patch(MODULE_PATH + '.settings')
|
||||
def test_default_for_invalid_type_bool(self, mock_settings):
|
||||
mock_settings.TEST_SETTING_DUMMY = 'invalid type'
|
||||
result = clean_setting(
|
||||
'TEST_SETTING_DUMMY',
|
||||
False
|
||||
)
|
||||
self.assertEqual(result, False)
|
||||
|
||||
@patch(MODULE_PATH + '.settings')
|
||||
def test_default_for_invalid_type_int(self, mock_settings):
|
||||
mock_settings.TEST_SETTING_DUMMY = 'invalid type'
|
||||
result = clean_setting(
|
||||
'TEST_SETTING_DUMMY',
|
||||
50
|
||||
)
|
||||
self.assertEqual(result, 50)
|
||||
|
||||
@patch(MODULE_PATH + '.settings')
|
||||
def test_default_if_below_minimum_1(self, mock_settings):
|
||||
mock_settings.TEST_SETTING_DUMMY = -5
|
||||
result = clean_setting(
|
||||
'TEST_SETTING_DUMMY',
|
||||
default_value=50
|
||||
)
|
||||
self.assertEqual(result, 50)
|
||||
|
||||
@patch(MODULE_PATH + '.settings')
|
||||
def test_default_if_below_minimum_2(self, mock_settings):
|
||||
mock_settings.TEST_SETTING_DUMMY = -50
|
||||
result = clean_setting(
|
||||
'TEST_SETTING_DUMMY',
|
||||
default_value=50,
|
||||
min_value=-10
|
||||
)
|
||||
self.assertEqual(result, 50)
|
||||
|
||||
@patch(MODULE_PATH + '.settings')
|
||||
def test_default_for_invalid_type_int_2(self, mock_settings):
|
||||
mock_settings.TEST_SETTING_DUMMY = 1000
|
||||
result = clean_setting(
|
||||
'TEST_SETTING_DUMMY',
|
||||
default_value=50,
|
||||
max_value=100
|
||||
)
|
||||
self.assertEqual(result, 50)
|
||||
|
||||
@patch(MODULE_PATH + '.settings')
|
||||
def test_default_is_none_needs_required_type(self, mock_settings):
|
||||
mock_settings.TEST_SETTING_DUMMY = 'invalid type'
|
||||
with self.assertRaises(ValueError):
|
||||
clean_setting(
|
||||
'TEST_SETTING_DUMMY',
|
||||
default_value=None
|
||||
)
|
||||
@@ -1,71 +1,167 @@
|
||||
from django_webtest import WebTest
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import TestCase
|
||||
from django.contrib.auth.models import User
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.conf import settings
|
||||
from django.test import TestCase, RequestFactory
|
||||
from django.urls import reverse
|
||||
|
||||
from allianceauth.tests.auth_utils import AuthUtils
|
||||
|
||||
from ..models import DiscordUser
|
||||
from ..manager import DiscordOAuthManager
|
||||
|
||||
from . import DEFAULT_AUTH_GROUP, add_permissions, MODULE_PATH
|
||||
from . import MODULE_PATH, add_permissions_to_members, TEST_USER_NAME, TEST_USER_ID
|
||||
from ..models import DiscordUser, DiscordClient
|
||||
from ..utils import set_logger_to_file
|
||||
from ..views import (
|
||||
discord_callback,
|
||||
reset_discord,
|
||||
deactivate_discord,
|
||||
discord_add_bot,
|
||||
activate_discord
|
||||
)
|
||||
|
||||
|
||||
class DiscordViewsTestCase(WebTest):
|
||||
logger = set_logger_to_file(MODULE_PATH + '.views', __file__)
|
||||
|
||||
|
||||
class SetupClassMixin(TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
cls.factory = RequestFactory()
|
||||
cls.user = AuthUtils.create_member(TEST_USER_NAME)
|
||||
add_permissions_to_members()
|
||||
cls.services_url = reverse('services:services')
|
||||
|
||||
|
||||
class TestActivateDiscord(SetupClassMixin, TestCase):
|
||||
|
||||
@patch(MODULE_PATH + '.views.DiscordUser.objects.generate_oauth_redirect_url')
|
||||
def test_redirects_to_correct_url(self, mock_generate_oauth_redirect_url):
|
||||
expected_url = '/example.com/oauth/'
|
||||
mock_generate_oauth_redirect_url.return_value = expected_url
|
||||
request = self.factory.get(reverse('discord:activate'))
|
||||
request.user = self.user
|
||||
response = activate_discord(request)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, expected_url)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.views.messages')
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient', spec=DiscordClient)
|
||||
class TestDeactivateDiscord(SetupClassMixin, TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.member = AuthUtils.create_member('auth_member')
|
||||
AuthUtils.add_main_character(self.member, 'test character', '1234', '2345', 'test corp', 'testc')
|
||||
add_permissions()
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
|
||||
def test_when_successful_show_success_message(
|
||||
self, mock_DiscordClient, mock_messages
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
request = self.factory.get(reverse('discord:deactivate'))
|
||||
request.user = self.user
|
||||
response = deactivate_discord(request)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, self.services_url)
|
||||
self.assertTrue(mock_messages.success.called)
|
||||
self.assertFalse(mock_messages.error.called)
|
||||
|
||||
def login(self):
|
||||
self.app.set_user(self.member)
|
||||
def test_when_unsuccessful_show_error_message(
|
||||
self, mock_DiscordClient, mock_messages
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = False
|
||||
request = self.factory.get(reverse('discord:deactivate'))
|
||||
request.user = self.user
|
||||
response = deactivate_discord(request)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, self.services_url)
|
||||
self.assertFalse(mock_messages.success.called)
|
||||
self.assertTrue(mock_messages.error.called)
|
||||
|
||||
@mock.patch(MODULE_PATH + '.views.DiscordOAuthManager')
|
||||
def test_activate(self, manager):
|
||||
self.login()
|
||||
manager.generate_oauth_redirect_url.return_value = '/example.com/oauth/'
|
||||
response = self.app.get('/discord/activate/', auto_follow=False)
|
||||
self.assertRedirects(
|
||||
response,
|
||||
expected_url="/example.com/oauth/",
|
||||
target_status_code=404,
|
||||
fetch_redirect_response=False,
|
||||
|
||||
@patch(MODULE_PATH + '.views.messages')
|
||||
@patch(MODULE_PATH + '.managers.DiscordClient')
|
||||
class TestResetDiscord(SetupClassMixin, TestCase):
|
||||
|
||||
def setUp(self):
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
|
||||
def test_when_successful_redirect_to_activate(
|
||||
self, mock_DiscordClient, mock_messages
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = True
|
||||
request = self.factory.get(reverse('discord:reset'))
|
||||
request.user = self.user
|
||||
response = reset_discord(request)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, reverse("discord:activate"))
|
||||
self.assertFalse(mock_messages.error.called)
|
||||
|
||||
def test_when_unsuccessful_message_error_and_redirect_to_service(
|
||||
self, mock_DiscordClient, mock_messages
|
||||
):
|
||||
mock_DiscordClient.return_value.remove_guild_member.return_value = False
|
||||
request = self.factory.get(reverse('discord:reset'))
|
||||
request.user = self.user
|
||||
response = reset_discord(request)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, self.services_url)
|
||||
self.assertTrue(mock_messages.error.called)
|
||||
|
||||
|
||||
@patch(MODULE_PATH + '.views.messages')
|
||||
@patch(MODULE_PATH + '.views.DiscordUser.objects.add_user')
|
||||
class TestDiscordCallback(SetupClassMixin, TestCase):
|
||||
|
||||
def setUp(self):
|
||||
DiscordUser.objects.create(user=self.user, uid=TEST_USER_ID)
|
||||
|
||||
def test_success_message_when_ok(self, mock_add_user, mock_messages):
|
||||
mock_add_user.return_value = True
|
||||
request = self.factory.get(
|
||||
reverse('discord:callback'), data={'code': '1234'}
|
||||
)
|
||||
request.user = self.user
|
||||
response = discord_callback(request)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, self.services_url)
|
||||
self.assertTrue(mock_messages.success.called)
|
||||
self.assertFalse(mock_messages.error.called)
|
||||
|
||||
def test_handle_no_code(self, mock_add_user, mock_messages):
|
||||
mock_add_user.return_value = True
|
||||
request = self.factory.get(
|
||||
reverse('discord:callback'), data={}
|
||||
)
|
||||
request.user = self.user
|
||||
response = discord_callback(request)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, self.services_url)
|
||||
self.assertFalse(mock_messages.success.called)
|
||||
self.assertTrue(mock_messages.error.called)
|
||||
|
||||
def test_error_message_when_user_creation_failed(
|
||||
self, mock_add_user, mock_messages
|
||||
):
|
||||
mock_add_user.return_value = False
|
||||
request = self.factory.get(
|
||||
reverse('discord:callback'), data={'code': '1234'}
|
||||
)
|
||||
request.user = self.user
|
||||
response = discord_callback(request)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, self.services_url)
|
||||
self.assertFalse(mock_messages.success.called)
|
||||
self.assertTrue(mock_messages.error.called)
|
||||
|
||||
@mock.patch(MODULE_PATH + '.tasks.DiscordOAuthManager')
|
||||
def test_callback(self, manager):
|
||||
self.login()
|
||||
manager.add_user.return_value = '1234'
|
||||
response = self.app.get('/discord/callback/', params={'code': '1234'})
|
||||
|
||||
self.member = User.objects.get(pk=self.member.pk)
|
||||
|
||||
self.assertTrue(manager.add_user.called)
|
||||
self.assertEqual(manager.update_nickname.called, settings.DISCORD_SYNC_NAMES)
|
||||
self.assertEqual(self.member.discord.uid, '1234')
|
||||
self.assertRedirects(response, expected_url='/services/', target_status_code=200)
|
||||
|
||||
@mock.patch(MODULE_PATH + '.tasks.DiscordOAuthManager')
|
||||
def test_reset(self, manager):
|
||||
self.login()
|
||||
DiscordUser.objects.create(user=self.member, uid='12345')
|
||||
manager.delete_user.return_value = True
|
||||
|
||||
response = self.app.get('/discord/reset/')
|
||||
|
||||
self.assertRedirects(response, expected_url='/discord/activate/', target_status_code=302)
|
||||
|
||||
@mock.patch(MODULE_PATH + '.tasks.DiscordOAuthManager')
|
||||
def test_deactivate(self, manager):
|
||||
self.login()
|
||||
DiscordUser.objects.create(user=self.member, uid='12345')
|
||||
manager.delete_user.return_value = True
|
||||
|
||||
response = self.app.get('/discord/deactivate/')
|
||||
|
||||
self.assertTrue(manager.delete_user.called)
|
||||
self.assertRedirects(response, expected_url='/services/', target_status_code=200)
|
||||
with self.assertRaises(ObjectDoesNotExist):
|
||||
discord_user = User.objects.get(pk=self.member.pk).discord
|
||||
@patch(MODULE_PATH + '.views.DiscordUser.objects.generate_bot_add_url')
|
||||
class TestDiscordAddBot(TestCase):
|
||||
|
||||
def test_add_bot(self, mock_generate_bot_add_url):
|
||||
bot_url = 'https://www.example.com/bot'
|
||||
mock_generate_bot_add_url.return_value = bot_url
|
||||
my_user = User.objects.create_superuser('Lex Luthor', 'abc', 'def')
|
||||
request = RequestFactory().get(reverse('discord:add_bot'))
|
||||
request.user = my_user
|
||||
response = discord_add_bot(request)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertEqual(response.url, bot_url)
|
||||
|
||||
89
allianceauth/services/modules/discord/utils.py
Normal file
89
allianceauth/services/modules/discord/utils.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoggerAddTag(logging.LoggerAdapter):
|
||||
"""add custom tag to a logger"""
|
||||
def __init__(self, logger, prefix):
|
||||
super(LoggerAddTag, self).__init__(logger, {})
|
||||
self.prefix = prefix
|
||||
|
||||
def process(self, msg, kwargs):
|
||||
return '[%s] %s' % (self.prefix, msg), kwargs
|
||||
|
||||
|
||||
def clean_setting(
|
||||
name: str,
|
||||
default_value: object,
|
||||
min_value: int = None,
|
||||
max_value: int = None,
|
||||
required_type: type = None
|
||||
):
|
||||
"""cleans the input for a custom setting
|
||||
|
||||
Will use `default_value` if settings does not exit or has the wrong type
|
||||
or is outside define boundaries (for int only)
|
||||
|
||||
Need to define `required_type` if `default_value` is `None`
|
||||
|
||||
Will assume `min_value` of 0 for int (can be overriden)
|
||||
|
||||
Returns cleaned value for setting
|
||||
"""
|
||||
if default_value is None and not required_type:
|
||||
raise ValueError('You must specify a required_type for None defaults')
|
||||
|
||||
if not required_type:
|
||||
required_type = type(default_value)
|
||||
|
||||
if min_value is None and required_type == int:
|
||||
min_value = 0
|
||||
|
||||
if not hasattr(settings, name):
|
||||
cleaned_value = default_value
|
||||
else:
|
||||
if (
|
||||
isinstance(getattr(settings, name), required_type)
|
||||
and (min_value is None or getattr(settings, name) >= min_value)
|
||||
and (max_value is None or getattr(settings, name) <= max_value)
|
||||
):
|
||||
cleaned_value = getattr(settings, name)
|
||||
else:
|
||||
logger.warning(
|
||||
'You setting for %s it not valid. Please correct it. '
|
||||
'Using default for now: %s',
|
||||
name,
|
||||
default_value
|
||||
)
|
||||
cleaned_value = default_value
|
||||
return cleaned_value
|
||||
|
||||
|
||||
def set_logger_to_file(logger_name: str, name: str) -> object:
|
||||
"""set logger for current module to log into a file. Useful for tests.
|
||||
|
||||
Args:
|
||||
- logger: current logger object
|
||||
- name: name of current module, e.g. __file__
|
||||
|
||||
Returns:
|
||||
- amended logger
|
||||
"""
|
||||
|
||||
# reconfigure logger so we get logging from tested module
|
||||
f_format = logging.Formatter(
|
||||
'%(asctime)s - %(levelname)s - %(module)s:%(funcName)s - %(message)s'
|
||||
)
|
||||
path = os.path.splitext(name)[0]
|
||||
f_handler = logging.FileHandler('{}.log'.format(path), 'w+')
|
||||
f_handler.setFormatter(f_format)
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.level = logging.DEBUG
|
||||
logger.addHandler(f_handler)
|
||||
logger.propagate = False
|
||||
return logger
|
||||
@@ -9,10 +9,12 @@ from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from allianceauth.services.views import superuser_test
|
||||
|
||||
from .manager import DiscordOAuthManager
|
||||
from .tasks import DiscordTasks
|
||||
from . import __title__
|
||||
from .models import DiscordUser
|
||||
from .utils import LoggerAddTag
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = LoggerAddTag(logging.getLogger(__name__), __title__)
|
||||
|
||||
ACCESS_PERM = 'discord.access_discord'
|
||||
|
||||
@@ -20,53 +22,94 @@ ACCESS_PERM = 'discord.access_discord'
|
||||
@login_required
|
||||
@permission_required(ACCESS_PERM)
|
||||
def deactivate_discord(request):
|
||||
logger.debug("deactivate_discord called by user %s" % request.user)
|
||||
if DiscordTasks.delete_user(request.user):
|
||||
logger.info("Successfully deactivated discord for user %s" % request.user)
|
||||
logger.debug("deactivate_discord called by user %s", request.user)
|
||||
if request.user.discord.delete_user(is_rate_limited=False):
|
||||
logger.info("Successfully deactivated discord for user %s", request.user)
|
||||
messages.success(request, _('Deactivated Discord account.'))
|
||||
else:
|
||||
logger.error("Unsuccessful attempt to deactivate discord for user %s" % request.user)
|
||||
messages.error(request, _('An error occurred while processing your Discord account.'))
|
||||
logger.error(
|
||||
"Unsuccessful attempt to deactivate discord for user %s", request.user
|
||||
)
|
||||
messages.error(
|
||||
request, _('An error occurred while processing your Discord account.')
|
||||
)
|
||||
return redirect("services:services")
|
||||
|
||||
|
||||
@login_required
|
||||
@permission_required(ACCESS_PERM)
|
||||
def reset_discord(request):
|
||||
logger.debug("reset_discord called by user %s" % request.user)
|
||||
if DiscordTasks.delete_user(request.user):
|
||||
logger.info("Successfully deleted discord user for user %s - forwarding to discord activation." % request.user)
|
||||
logger.debug("reset_discord called by user %s", request.user)
|
||||
if request.user.discord.delete_user(is_rate_limited=False):
|
||||
logger.info(
|
||||
"Successfully deleted discord user for user %s - "
|
||||
"forwarding to discord activation.",
|
||||
request.user
|
||||
)
|
||||
return redirect("discord:activate")
|
||||
logger.error("Unsuccessful attempt to reset discord for user %s" % request.user)
|
||||
messages.error(request, _('An error occurred while processing your Discord account.'))
|
||||
|
||||
logger.error(
|
||||
"Unsuccessful attempt to reset discord for user %s", request.user
|
||||
)
|
||||
messages.error(
|
||||
request, _('An error occurred while processing your Discord account.')
|
||||
)
|
||||
return redirect("services:services")
|
||||
|
||||
|
||||
@login_required
|
||||
@permission_required(ACCESS_PERM)
|
||||
def activate_discord(request):
|
||||
logger.debug("activate_discord called by user %s" % request.user)
|
||||
return redirect(DiscordOAuthManager.generate_oauth_redirect_url())
|
||||
logger.debug("activate_discord called by user %s", request.user)
|
||||
return redirect(DiscordUser.objects.generate_oauth_redirect_url())
|
||||
|
||||
|
||||
@login_required
|
||||
@permission_required(ACCESS_PERM)
|
||||
def discord_callback(request):
|
||||
logger.debug("Received Discord callback for activation of user %s" % request.user)
|
||||
code = request.GET.get('code', None)
|
||||
if not code:
|
||||
logger.warn("Did not receive OAuth code from callback of user %s" % request.user)
|
||||
return redirect("services:services")
|
||||
if DiscordTasks.add_user(request.user, code):
|
||||
logger.info("Successfully activated Discord for user %s" % request.user)
|
||||
messages.success(request, _('Activated Discord account.'))
|
||||
logger.debug(
|
||||
"Received Discord callback for activation of user %s", request.user
|
||||
)
|
||||
authorization_code = request.GET.get('code', None)
|
||||
if not authorization_code:
|
||||
logger.warning(
|
||||
"Did not receive OAuth code from callback for user %s", request.user
|
||||
)
|
||||
success = False
|
||||
else:
|
||||
logger.error("Failed to activate Discord for user %s" % request.user)
|
||||
messages.error(request, _('An error occurred while processing your Discord account.'))
|
||||
if DiscordUser.objects.add_user(
|
||||
user=request.user,
|
||||
authorization_code=authorization_code,
|
||||
is_rate_limited=False
|
||||
):
|
||||
logger.info(
|
||||
"Successfully activated Discord account for user %s", request.user
|
||||
)
|
||||
success = True
|
||||
|
||||
else:
|
||||
logger.error(
|
||||
"Failed to activate Discord account for user %s", request.user
|
||||
)
|
||||
success = False
|
||||
|
||||
if success:
|
||||
messages.success(
|
||||
request, _('Your Discord account has been successfully activated.')
|
||||
)
|
||||
else:
|
||||
messages.error(
|
||||
request,
|
||||
_(
|
||||
'An error occurred while trying to activate your Discord account. '
|
||||
'Please try again.'
|
||||
)
|
||||
)
|
||||
|
||||
return redirect("services:services")
|
||||
|
||||
|
||||
@login_required
|
||||
@user_passes_test(superuser_test)
|
||||
def discord_add_bot(request):
|
||||
return redirect(DiscordOAuthManager.generate_bot_add_url())
|
||||
return redirect(DiscordUser.objects.generate_bot_add_url())
|
||||
|
||||
Reference in New Issue
Block a user