import logging
from typing import Literal, Any, NamedTuple
from unittest import IsolatedAsyncioTestCase
import re
from httpx import Response, AsyncClient
# Root URL of the server under test; request() hands it to AsyncClient as
# base_url, so the paths used in the tests are resolved against it.
BASE_URL = 'http://localhost:8000'
# Module-level logger named after this module; used by request() when verbose.
log = logging.getLogger(__name__)
async def request(method: Literal['GET', 'POST', 'PUT', 'DELETE'] = 'GET',
                  path: str = '/',
                  data: dict[str, Any] | None = None,
                  headers: dict[str, Any] | None = None,
                  cookies: dict[str, Any] | None = None,
                  verbose: bool = False,
                  ) -> Response:
    """Send one HTTP request to the server at ``BASE_URL`` and return the response.

    A new ``AsyncClient`` is opened for every call and closed before the
    function returns, so nothing is shared between calls.

    Args:
        method: HTTP verb to use.
        path: URL path, resolved against ``BASE_URL``.
        data: Form fields forwarded to httpx as ``data``; ``None`` for no body.
        headers: Extra request headers, or ``None``.
        cookies: Cookies to send with the request, or ``None``.
        verbose: When true, log the call arguments at WARNING level first.

    Returns:
        The ``Response`` produced by ``AsyncClient.request``.
    """
    if verbose:
        # Logged before any other local is bound, so locals() holds exactly
        # the call arguments.
        log.warning('Request: %s', locals())
    # Cookies are set on the client rather than on the individual request:
    # httpx deprecates per-request ``cookies=``. The client only lives for
    # this one request, so the same cookies are sent either way.
    async with AsyncClient(base_url=BASE_URL, cookies=cookies) as client:
        return await client.request(
            method=method, url=path, data=data, headers=headers)
# CSRF pair collected by get_csrf():
#   header - value of the 'csrftoken' cookie set by the response
#   token  - value of the hidden 'csrfmiddlewaretoken' form input
CSRF = NamedTuple('CSRF', [('header', str), ('token', str)])
async def get_csrf(url: str) -> CSRF:
    """GET *url* and extract its CSRF cookie and hidden form token.

    Args:
        url: Path of a page expected to contain a hidden
            ``csrfmiddlewaretoken`` input.

    Returns:
        A ``CSRF`` whose ``header`` is the value of the ``csrftoken`` response
        cookie and whose ``token`` is the value of the hidden form input.

    Raises:
        LookupError: The response body has no matching hidden input. Without
            this the function would fall through and return ``None``,
            contradicting its annotation and failing later in the caller.
        KeyError: Presumably raised by the cookie lookup when the response
            sets no ``csrftoken`` cookie (httpx behaviour, not shown here).
    """
    pattern = r'<input type="hidden" name="csrfmiddlewaretoken" value="([a-zA-Z0-9]+)">'
    resp = await request('GET', url)
    found = re.search(pattern, resp.text)
    if found is None:
        raise LookupError(
            f'no csrfmiddlewaretoken input found in response for {url!r} '
            f'(status {resp.status_code})')
    return CSRF(header=resp.cookies['csrftoken'], token=found.group(1))
class IndexTest(IsolatedAsyncioTestCase):
    """Checks against the site root path ``/``."""

    async def test_index_redirects_to_login_page(self):
        # A plain GET of '/' (no cookies sent) must answer with a redirect.
        response = await request(method='GET', path='/')
        self.assertEqual(response.status_code, 302)

        # The redirect has to point at the login page. This is read only
        # after the status check, as next_request is dereferenced here.
        redirect_url = response.next_request.url
        self.assertEqual(redirect_url.path, '/login/')

        content_type = response.headers['Content-Type']
        self.assertEqual(content_type, 'text/html; charset=utf-8')
class LoginTest(IsolatedAsyncioTestCase):
    """Checks against ``/login/``: form markup, CSRF enforcement, and login results."""

    # Error text used to tell a rejected login from an accepted one.
    _BAD_CREDENTIALS = 'Please enter the correct username and password for a staff account.'

    async def _submit_login(self, password):
        # Fetch a fresh CSRF pair from the login page, then POST the login
        # form as 'admin' with the given password and return the response.
        csrf = await get_csrf('/login/')
        form = {
            'username': 'admin',
            'password': password,
            'csrfmiddlewaretoken': csrf.token,
        }
        return await request(
            method='POST',
            path='/login/',
            data=form,
            cookies={'csrftoken': csrf.header},
        )

    async def test_login_has_form(self):
        # The page must contain username/password inputs and a submit button.
        page = await request('GET', '/login/')
        expected_fragments = (
            'input type="text" name="username"',
            'input type="password" name="password"',
            'button type="submit"',
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, page.text)

    async def test_login_csrf_required(self):
        # Posting credentials without any CSRF token or cookie must be refused.
        credentials = {'username': 'admin', 'password': 'valid'}
        rejected = await request('POST', '/login/', data=credentials)
        self.assertEqual(rejected.status_code, 403)
        for fragment in ('<title>403 Forbidden</title>', 'CSRF verification failed.'):
            self.assertIn(fragment, rejected.text)

    async def test_login_csrf_exists(self):
        # Both halves of the CSRF pair must be present and purely alphanumeric.
        csrf = await get_csrf('/login/')
        values = (csrf.header, csrf.token)
        for value in values:
            self.assertIsNotNone(value)
        for value in values:
            self.assertRegex(value, r'^[a-zA-Z0-9]+$')

    async def test_login_failed(self):
        # A wrong password brings back the page carrying the error message.
        outcome = await self._submit_login('invalid')
        self.assertIn(self._BAD_CREDENTIALS, outcome.text)

    async def test_login_success(self):
        # A correct password gives no error message and a redirect to the profile.
        outcome = await self._submit_login('valid')
        self.assertNotIn(self._BAD_CREDENTIALS, outcome.text)
        self.assertEqual(outcome.status_code, 302)
        self.assertEqual(outcome.next_request.url.path, '/accounts/profile/')