测试在软件开发中扮演着至关重要的角色:
- 在代码部署前发现问题,降低修复成本
- 修改代码时确保已有功能不受影响
- 测试代码是最好的功能文档
| 框架 | 特点 | 安装命令 | 适用场景 |
|---|---|---|---|
| unittest | Python标准库,无需额外安装 | 内置 | 基础测试,简单项目 |
| pytest | 功能强大,语法简洁,插件丰富 | `pip install pytest` | 大多数Flask项目首选 |
| nose2 | unittest扩展,兼容unittest | `pip install nose2` | 从unittest迁移 |
# Install pytest and the Flask testing extensions (plus coverage support)
pip install pytest pytest-flask pytest-cov
# For mocking external dependencies (provides the `mocker` fixture)
pip install pytest-mock
# Database testing helpers
pip install pytest-flask-sqlalchemy
# Optional: richer assertion helpers
pip install pytest-assert-utils
# pytest.ini
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
# Multi-line values: every continuation line must be indented,
# otherwise pytest's ini parser treats it as a new key.
addopts =
    --verbose
    --tb=short
    --strict-markers
# With --strict-markers, every custom marker used in the suite must be listed here
# (plugin markers such as `asyncio` or `timeout` are registered by their plugins).
markers =
    slow: marks tests as slow (deselect with '-m "not slow"')
    integration: integration tests
    e2e: end-to-end tests
# tests/conftest.py
import pytest
import sys
import os

# Put the project root on sys.path so `import app` works without installing the package.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from app import create_app
from app.extensions import db


@pytest.fixture(scope='session')
def app():
    """Build the Flask application used by the whole test session.

    Uses an in-memory SQLite database and a fixed secret key, and disables
    CSRF so tests can post forms without a token.
    """
    test_app = create_app({
        'TESTING': True,
        'SQLALCHEMY_DATABASE_URI': 'sqlite:///:memory:',
        'SQLALCHEMY_TRACK_MODIFICATIONS': False,
        'SECRET_KEY': 'test-secret-key',
        'WTF_CSRF_ENABLED': False,  # disable CSRF so form posts work in tests
    })
    return test_app


@pytest.fixture(scope='function')
def client(app):
    """Yield a Flask test client (context-managed, so the last request's context stays inspectable)."""
    with app.test_client() as test_client:
        yield test_client


@pytest.fixture(scope='function')
def db_session(app):
    """Create all tables for one test and drop them afterwards.

    Yields the Flask-SQLAlchemy ``db`` object (use ``db_session.session``)
    while an application context is pushed, so ``Model.query`` works in the test.
    """
    with app.app_context():
        db.create_all()
        yield db
        # Teardown: discard the scoped session, then drop the schema.
        db.session.remove()
        db.drop_all()


@pytest.fixture(scope='function')
def runner(app):
    """Return a runner for invoking the app's CLI commands."""
    return app.test_cli_runner()


@pytest.fixture(scope='session', autouse=True)
def setup_test_environment():
    """Print markers around the whole test session (runs automatically)."""
    # Before any test runs
    print("\n=== 设置测试环境 ===")
    yield
    # After all tests have finished
    print("\n=== 清理测试环境 ===")
# tests/test_basic.py
import pytest


def test_index_route(client):
    """The home page renders successfully as HTML."""
    response = client.get('/')
    assert response.status_code == 200
    # Either greeting text is acceptable on the landing page.
    assert b'Welcome' in response.data or b'Home' in response.data
    assert 'text/html' in response.content_type


def test_404_error(client):
    """Unknown URLs return 404."""
    response = client.get('/nonexistent-page')
    assert response.status_code == 404


def test_redirect(client):
    """The old URL redirects to the new one."""
    response = client.get('/old-url')
    # Without following, we see the redirect itself.
    assert response.status_code == 302
    assert '/new-url' in response.location

    # Following the redirect lands on a page that renders.
    response = client.get('/old-url', follow_redirects=True)
    assert response.status_code == 200
# tests/test_config.py
import pytest


def test_app_config(app):
    """The overrides passed by the `app` fixture are applied."""
    assert app.config['TESTING'] is True
    assert app.config['SECRET_KEY'] == 'test-secret-key'
    assert app.config['SQLALCHEMY_DATABASE_URI'] == 'sqlite:///:memory:'
    assert app.config['WTF_CSRF_ENABLED'] is False


def test_app_debug_mode(app):
    """A production-configured app must run with debug and testing off."""
    if app.config.get('ENVIRONMENT') == 'production':
        assert not app.debug
        assert not app.config['TESTING']
    # Other environments may enable debug; nothing to assert.


def test_app_extensions(app):
    """Extensions are registered and the database points at in-memory SQLite."""
    assert hasattr(app, 'extensions')
    assert 'sqlalchemy' in app.extensions

    from app.extensions import db
    # db.engine is resolved through current_app, so an application context
    # must be active; outside one this raises RuntimeError.
    with app.app_context():
        assert db.engine.url.database == ':memory:'
# tests/test_views.py
import json


def test_get_user_profile(client, db_session):
    """A logged-in user can view their profile page."""
    from app.models import User

    # The login below can only succeed if the account exists, so create it first.
    user = User(username='testuser', email='test@example.com')
    user.set_password('password123')
    db_session.session.add(user)
    db_session.session.commit()

    client.post('/login', data={
        'username': 'testuser',
        'password': 'password123'
    })

    response = client.get('/profile')
    assert response.status_code == 200
    assert b'Profile' in response.data
    assert b'testuser' in response.data


def test_get_with_query_params(client):
    """Query-string parameters are reflected in the search results page."""
    response = client.get('/search?q=flask&page=1')
    assert response.status_code == 200
    assert b'flask' in response.data.lower()
    # The page number from the query string is rendered.
    assert b'page 1' in response.data


def test_get_json_api(client, db_session):
    """The user API endpoint returns the user as JSON."""
    from app.models import User

    # Create the record instead of assuming a user with id 1 already exists.
    user = User(username='apiuser', email='api@example.com')
    db_session.session.add(user)
    db_session.session.commit()

    response = client.get(f'/api/users/{user.id}')
    assert response.status_code == 200
    assert response.content_type == 'application/json'

    payload = json.loads(response.data)
    # The user is wrapped in a "user" envelope, matching the shape asserted
    # for this endpoint in tests/test_api.py.
    data = payload['user']
    assert 'id' in data
    assert 'name' in data
    assert 'email' in data

    assert data['id'] == user.id
    assert isinstance(data['name'], str)
# tests/test_auth.py
def test_user_registration_success(client, db_session):
    """A valid registration creates the user and shows a welcome page."""
    data = {
        'username': 'newuser',
        'email': 'newuser@example.com',
        'password': 'securepassword123',
        'confirm_password': 'securepassword123'
    }
    response = client.post('/register', data=data, follow_redirects=True)

    assert response.status_code == 200
    assert b'Registration successful' in response.data
    assert b'Welcome' in response.data

    # The user row was persisted.
    from app.models import User
    user = User.query.filter_by(username='newuser').first()
    assert user is not None
    assert user.email == 'newuser@example.com'


def test_user_registration_failure(client, db_session):
    """Invalid registrations re-render the form with an error message."""
    from app.models import User

    # Mismatched passwords
    data = {
        'username': 'newuser',
        'email': 'newuser@example.com',
        'password': 'password123',
        'confirm_password': 'differentpassword'
    }
    response = client.post('/register', data=data)
    assert response.status_code == 200  # stays on the registration page
    assert b'Passwords do not match' in response.data

    # Duplicate email: create the existing account first; otherwise there is
    # nothing for the uniqueness check to collide with.
    existing = User(username='existinguser', email='existing@example.com')
    existing.set_password('password123')
    db_session.session.add(existing)
    db_session.session.commit()

    data = {
        'username': 'anotheruser',
        'email': 'existing@example.com',
        'password': 'password123',
        'confirm_password': 'password123'
    }
    response = client.post('/register', data=data)
    assert b'Email already registered' in response.data

    # Missing required fields
    response = client.post('/register', data={'username': 'incomplete'})
    assert b'This field is required' in response.data
# tests/test_forms.py
import pytest
from app.forms import RegistrationForm, LoginForm
from wtforms.validators import ValidationError


@pytest.fixture(autouse=True)
def _request_context(app):
    """Push a test request context for every test in this module.

    NOTE(review): assumes the forms are Flask-WTF ``FlaskForm`` subclasses,
    which read ``request`` / ``current_app`` when constructed and fail
    outside a request context.
    """
    with app.test_request_context():
        yield


def test_registration_form_valid():
    """Well-formed registration data passes validation."""
    form = RegistrationForm(data={
        'username': 'validuser',
        'email': 'valid@example.com',
        'password': 'ValidPass123',
        'confirm_password': 'ValidPass123'
    })
    assert form.validate() is True
    assert not form.errors


def test_registration_form_invalid():
    """Each invalid field is reported under its own key in form.errors."""
    # Weak password
    form = RegistrationForm(data={
        'username': 'test',
        'email': 'test@example.com',
        'password': '123',
        'confirm_password': '123'
    })
    assert form.validate() is False
    assert 'password' in form.errors
    assert len(form.errors['password']) > 0

    # Malformed email
    form = RegistrationForm(data={
        'username': 'test',
        'email': 'invalid-email',
        'password': 'Password123',
        'confirm_password': 'Password123'
    })
    assert form.validate() is False
    assert 'email' in form.errors

    # Passwords do not match
    form = RegistrationForm(data={
        'username': 'test',
        'email': 'test@example.com',
        'password': 'Password123',
        'confirm_password': 'Different123'
    })
    assert form.validate() is False
    assert 'confirm_password' in form.errors


@pytest.mark.parametrize('username,expected', [
    ('validname', True),
    ('a', False),  # too short
    ('user_name', True),  # underscores allowed; kept short so length is not the deciding factor
    ('user@name', False),  # special character
    ('toolongusername12345', False),  # too long (20 characters)
])
def test_username_validation(username, expected):
    """Parametrized username validation."""
    form = RegistrationForm(data={
        'username': username,
        'email': 'test@example.com',
        'password': 'Password123',
        'confirm_password': 'Password123'
    })
    # Run full validation: every other field is valid, so only the username
    # can produce an error. Calling form.validate_username() directly would
    # raise ValidationError instead of populating form.errors, and would skip
    # the field's built-in validators (length, pattern).
    form.validate()
    if expected:
        assert 'username' not in form.errors
    else:
        assert 'username' in form.errors
# tests/test_models.py
import pytest
from datetime import datetime, timedelta
from app.models import User, Post, db


def test_user_creation(db_session):
    """A saved user gets an id, defaults, a repr and working password helpers."""
    user = User(
        username='testuser',
        email='test@example.com',
        password_hash='hashed_password'
    )
    db_session.session.add(user)
    db_session.session.commit()

    assert user.id is not None
    assert user.username == 'testuser'
    assert user.email == 'test@example.com'
    assert user.created_at is not None
    assert user.is_active is True

    assert str(user) == '<User testuser>'

    user.set_password('mypassword')
    assert user.check_password('mypassword') is True
    assert user.check_password('wrongpassword') is False


def test_user_relationships(db_session):
    """User.posts and Post.author mirror each other (one-to-many)."""
    user = User(username='author', email='author@example.com')
    post1 = Post(title='First Post', content='Content 1', author=user)
    post2 = Post(title='Second Post', content='Content 2', author=user)
    db_session.session.add_all([user, post1, post2])
    db_session.session.commit()

    assert len(user.posts) == 2
    assert post1.author == user
    assert post2.author == user
    assert post1 in user.posts
    assert post2 in user.posts


def test_post_timestamps(db_session):
    """created_at is set on insert and updated_at advances on update."""
    import time

    user = User(username='test', email='test@example.com')
    post = Post(title='Test Post', content='Content', author=user)

    before_creation = datetime.utcnow()
    db_session.session.add(post)
    db_session.session.commit()
    after_creation = datetime.utcnow()

    assert before_creation <= post.created_at <= after_creation

    original_updated_at = post.updated_at
    # Sleep briefly so the new timestamp is strictly greater.
    time.sleep(0.01)
    post.content = 'Updated content'
    db_session.session.commit()
    assert post.updated_at > original_updated_at


@pytest.mark.parametrize('username,email,expected', [
    ('user1', 'user1@example.com', True),  # valid
    ('', 'user@example.com', False),  # empty username
    ('user', '', False),  # empty email
    ('user', 'invalid-email', False),  # malformed email
])
def test_user_validation(username, email, expected, db_session):
    """Invalid users are rejected; valid ones are saved."""
    if expected:
        user = User(username=username, email=email)
        db_session.session.add(user)
        db_session.session.commit()
        assert user.id is not None
    else:
        # The constructor is inside the block because model-level validators
        # (e.g. SQLAlchemy @validates) raise at construction time.
        # NOTE(review): Exception is deliberately broad: the failure may come
        # from a validator or the database. Narrow it once the model's
        # behaviour is known. Empty strings pass NOT NULL, so these cases rely
        # on model validators existing.
        with pytest.raises(Exception):
            user = User(username=username, email=email)
            db_session.session.add(user)
            db_session.session.commit()
        # Reached after the expected failure: clear the aborted transaction.
        db_session.session.rollback()
# tests/test_crud.py
import pytest
from app.models import User, Post, Comment
from app.services import UserService, PostService


class TestUserService:
    """Tests for UserService CRUD operations."""

    def test_create_user(self, db_session):
        """create_user persists a user whose password can be checked."""
        service = UserService(db_session)
        user_data = {
            'username': 'newuser',
            'email': 'newuser@example.com',
            'password': 'password123'
        }
        user = service.create_user(**user_data)

        assert user.id is not None
        assert user.username == 'newuser'
        assert user.email == 'newuser@example.com'
        assert user.check_password('password123')

        # Session.get replaces the legacy Query.get (deprecated in SQLAlchemy 2.0).
        db_user = db_session.session.get(User, user.id)
        assert db_user == user

    def test_get_user_by_email(self, db_session):
        """Lookup by email returns the user, or None when absent."""
        user = User(username='testuser', email='test@example.com')
        user.set_password('password123')
        db_session.session.add(user)
        db_session.session.commit()

        service = UserService(db_session)
        found_user = service.get_user_by_email('test@example.com')
        assert found_user is not None
        assert found_user.id == user.id
        assert found_user.username == 'testuser'

        not_found = service.get_user_by_email('nonexistent@example.com')
        assert not_found is None

    def test_update_user(self, db_session):
        """update_user changes the fields and persists them."""
        user = User(username='oldname', email='old@example.com')
        db_session.session.add(user)
        db_session.session.commit()

        service = UserService(db_session)
        updated_user = service.update_user(
            user.id,
            username='newname',
            email='new@example.com'
        )
        assert updated_user.username == 'newname'
        assert updated_user.email == 'new@example.com'

        db_user = db_session.session.get(User, user.id)
        assert db_user.username == 'newname'

    def test_delete_user(self, db_session):
        """delete_user returns True on success and False for unknown ids."""
        user = User(username='todelete', email='delete@example.com')
        db_session.session.add(user)
        db_session.session.commit()
        user_id = user.id

        service = UserService(db_session)
        assert service.delete_user(user_id) is True
        assert db_session.session.get(User, user_id) is None

        # Deleting a missing user is reported, not raised.
        assert service.delete_user(99999) is False


class TestPostService:
    """Tests for PostService."""

    def test_create_post(self, db_session):
        """create_post stores the post and publishes it by default."""
        user = User(username='author', email='author@example.com')
        db_session.session.add(user)
        db_session.session.commit()

        service = PostService(db_session)
        post_data = {
            'title': 'Test Post',
            'content': 'This is a test post.',
            'author_id': user.id
        }
        post = service.create_post(**post_data)

        assert post.id is not None
        assert post.title == 'Test Post'
        assert post.content == 'This is a test post.'
        assert post.author_id == user.id
        assert post.is_published is True  # published by default

    def test_get_published_posts(self, db_session):
        """Only published posts are returned, optionally paginated."""
        user = User(username='author', email='author@example.com')
        posts = [
            Post(title='Published 1', content='Content 1', author=user, is_published=True),
            Post(title='Published 2', content='Content 2', author=user, is_published=True),
            Post(title='Draft', content='Draft content', author=user, is_published=False),
        ]
        db_session.session.add_all([user] + posts)
        db_session.session.commit()

        service = PostService(db_session)
        published_posts = service.get_published_posts()
        assert len(published_posts) == 2
        assert all(post.is_published for post in published_posts)

        # With paging arguments the service returns a pagination object.
        paginated = service.get_published_posts(page=1, per_page=1)
        assert len(paginated.items) == 1
        assert paginated.total == 2
        assert paginated.pages == 2

    def test_search_posts(self, db_session):
        """Search matches the keyword in title or content."""
        user = User(username='author', email='author@example.com')
        posts = [
            Post(title='Python Tutorial', content='Learn Python', author=user),
            Post(title='Flask Guide', content='Flask web framework', author=user),
            Post(title='Django vs Flask', content='Comparison', author=user),
        ]
        db_session.session.add_all([user] + posts)
        db_session.session.commit()

        service = PostService(db_session)
        # 'Flask' appears in two posts: in both title and content of one,
        # and in the title only of the other.
        flask_posts = service.search_posts('Flask')
        assert len(flask_posts) == 2

        assert len(service.search_posts('Nonexistent')) == 0

        # An empty search term returns every post.
        assert len(service.search_posts('')) == 3
# tests/test_external_api.py
import pytest
from unittest.mock import Mock, patch, MagicMock
import requests

# NOTE(review): every test patches 'requests.get' on the requests module
# itself. That only intercepts calls written as requests.get(...). If
# app.services does `from requests import get`, patch 'app.services.get'.


def test_external_api_success(mocker):
    """get_user_data returns the 'data' payload of a successful response."""
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        'data': {
            'id': 1,
            'name': 'John Doe',
            'email': 'john@example.com'
        }
    }
    # mocker.patch is undone automatically when the test finishes.
    mock_get = mocker.patch('requests.get', return_value=mock_response)

    from app.services import ExternalAPIService
    service = ExternalAPIService()
    user_data = service.get_user_data(1)

    mock_get.assert_called_once_with('https://api.example.com/users/1')
    assert user_data['id'] == 1
    assert user_data['name'] == 'John Doe'
    assert user_data['email'] == 'john@example.com'


def test_external_api_failure(mocker):
    """Network errors are handled and reported as None."""
    mocker.patch(
        'requests.get',
        side_effect=requests.exceptions.RequestException('Network error'),
    )

    from app.services import ExternalAPIService
    service = ExternalAPIService()
    assert service.get_user_data(1) is None


def test_external_api_rate_limit(mocker):
    """A 429 response surfaces the Retry-After delay."""
    mock_response = Mock()
    mock_response.status_code = 429
    mock_response.headers = {'Retry-After': '60'}
    mocker.patch('requests.get', return_value=mock_response)

    from app.services import ExternalAPIService
    service = ExternalAPIService()
    result = service.call_with_retry()
    # Expected shape depends on the implementation; here Retry-After is parsed to int.
    assert result.get('retry_after') == 60


@pytest.mark.parametrize('status_code,expected_result', [
    (200, 'success'),
    (404, 'not_found'),
    (500, 'error'),
    (401, 'unauthorized'),
])
def test_api_status_codes(status_code, expected_result, mocker):
    """handle_response maps HTTP status codes to result labels."""
    mock_response = Mock()
    mock_response.status_code = status_code
    mocker.patch('requests.get', return_value=mock_response)

    from app.services import ExternalAPIService
    service = ExternalAPIService()
    assert service.handle_response(mock_response) == expected_result
# tests/test_file_service.py
import pytest
from unittest.mock import Mock, patch, mock_open
import os
import tempfile
import shutil


class TestFileService:
    """Tests for FileService with filesystem calls mocked out."""

    def test_save_file_success(self, mocker):
        """save_file creates the directory and writes the upload's bytes."""
        # Import before patching, so a first-time import of app.services
        # never runs against the mocked os / open functions.
        from app.services import FileService

        mock_file = Mock()
        mock_file.filename = 'test.txt'
        mock_file.read.return_value = b'file content'

        # Target directory does not exist yet.
        mocker.patch('os.path.exists', return_value=False)
        mock_makedirs = mocker.patch('os.makedirs')

        mock_file_obj = mock_open()
        with patch('builtins.open', mock_file_obj):
            service = FileService()
            file_path = service.save_file(mock_file, '/uploads')

        mock_makedirs.assert_called_once_with('/uploads', exist_ok=True)
        mock_file_obj.assert_called_once_with('/uploads/test.txt', 'wb')
        # mock_open returns the same handle on every call.
        handle = mock_file_obj()
        handle.write.assert_called_once_with(b'file content')
        assert file_path == '/uploads/test.txt'

    def test_save_file_io_error(self, mocker):
        """I/O errors from writing the file propagate to the caller."""
        from app.services import FileService

        mock_file = Mock()
        mock_file.filename = 'test.txt'
        mock_file.read.return_value = b'file content'

        # Keep the test off the real filesystem. save_file creates the target
        # directory when it is missing, and creating /uploads for real would
        # fail with PermissionError (also an OSError, but not 'Disk full')
        # before the mocked open() is ever reached.
        mocker.patch('os.path.exists', return_value=True)
        mocker.patch('os.makedirs')

        with patch('builtins.open', side_effect=IOError('Disk full')):
            service = FileService()
            with pytest.raises(IOError, match='Disk full'):
                service.save_file(mock_file, '/uploads')

    def test_delete_file_success(self, mocker):
        """An existing file is removed and True is returned."""
        from app.services import FileService

        mocker.patch('os.path.exists', return_value=True)
        mock_remove = mocker.patch('os.remove')

        service = FileService()
        result = service.delete_file('/uploads/test.txt')

        mock_remove.assert_called_once_with('/uploads/test.txt')
        assert result is True

    def test_delete_file_not_found(self, mocker):
        """Deleting a missing file returns False instead of raising."""
        from app.services import FileService

        mocker.patch('os.path.exists', return_value=False)

        service = FileService()
        assert service.delete_file('/uploads/nonexistent.txt') is False

    def test_get_file_size(self, mocker):
        """get_file_size returns the size, or 0 when the file is missing."""
        from app.services import FileService

        mocker.patch('os.path.getsize', return_value=1024)
        service = FileService()
        assert service.get_file_size('/uploads/test.txt') == 1024

        # Re-patch so the lookup now fails.
        mocker.patch('os.path.getsize', side_effect=FileNotFoundError)
        assert service.get_file_size('/uploads/nonexistent.txt') == 0
# tests/test_auth_integration.py
import pytest
from flask import session, url_for
from flask_login import current_user


def test_login_success(client, db_session):
    """Valid credentials log the user in and land on the dashboard."""
    from app.models import User

    user = User(username='testuser', email='test@example.com')
    user.set_password('password123')
    db_session.session.add(user)
    db_session.session.commit()

    response = client.post('/login', data={
        'username': 'testuser',
        'password': 'password123'
    }, follow_redirects=True)
    assert response.status_code == 200

    # Flask-Login's login_user() stores the id under '_user_id' (as a
    # string). A plain 'user_id' key is never set by Flask-Login.
    with client.session_transaction() as sess:
        assert '_user_id' in sess
        assert sess['_user_id'] == str(user.id)

    # The redirect ended on the dashboard.
    assert b'Dashboard' in response.data
    assert b'Welcome testuser' in response.data


def test_login_failure(client, db_session):
    """Wrong password and unknown user both show the same generic error."""
    from app.models import User

    user = User(username='testuser', email='test@example.com')
    user.set_password('password123')
    db_session.session.add(user)
    db_session.session.commit()

    response = client.post('/login', data={
        'username': 'testuser',
        'password': 'wrongpassword'
    })
    assert response.status_code == 200
    assert b'Invalid username or password' in response.data

    response = client.post('/login', data={
        'username': 'nonexistent',
        'password': 'password123'
    })
    assert b'Invalid username or password' in response.data


def test_logout(client, auth_user):
    """After logout, protected pages are no longer accessible."""
    client.post('/login', data={
        'username': 'testuser',
        'password': 'password123'
    })
    assert client.get('/profile').status_code == 200

    response = client.get('/logout', follow_redirects=True)
    assert response.status_code == 200
    assert b'You have been logged out' in response.data

    # Expect a redirect to login or a 403, but not a rendered page.
    assert client.get('/profile').status_code != 200


def test_protected_route_access(client, auth_user):
    """Anonymous users are redirected to login; logged-in users get through."""
    response = client.get('/dashboard', follow_redirects=False)
    assert response.status_code == 302
    assert '/login' in response.location

    client.post('/login', data={
        'username': 'testuser',
        'password': 'password123'
    })
    response = client.get('/dashboard')
    assert response.status_code == 200
    assert b'Dashboard' in response.data


@pytest.fixture
def auth_user(db_session):
    """Create and return 'testuser' with password 'password123'."""
    from app.models import User

    user = User(username='testuser', email='test@example.com')
    user.set_password('password123')
    db_session.session.add(user)
    db_session.session.commit()
    return user
# tests/test_sqlalchemy_integration.py
import pytest
from sqlalchemy.exc import IntegrityError
from app.models import User, Post, db


def test_database_transaction_rollback(db_session):
    """Uncommitted work is pending and rollback discards it."""
    user = User(username='testuser', email='test@example.com')
    db_session.session.add(user)

    # The user is pending: added to the session but not flushed or committed.
    # (Running User.query here would autoflush the INSERT into the open
    # transaction, and the query would then see the row.)
    assert user in db_session.session.new
    assert user.id is None

    db_session.session.commit()
    assert len(User.query.all()) == 1

    new_user = User(username='another', email='another@example.com')
    db_session.session.add(new_user)
    db_session.session.rollback()

    assert len(User.query.all()) == 1  # only the committed user remains


def test_unique_constraint(db_session):
    """A duplicate username violates the unique constraint."""
    user1 = User(username='uniqueuser', email='unique@example.com')
    db_session.session.add(user1)
    db_session.session.commit()

    user2 = User(username='uniqueuser', email='different@example.com')
    db_session.session.add(user2)
    with pytest.raises(IntegrityError):
        db_session.session.commit()
    # Clear the failed transaction so the session is usable again.
    db_session.session.rollback()


def test_foreign_key_constraint(db_session):
    """A post without an author is rejected; with one it saves."""
    # A post with no author_id is rejected by the column's NOT NULL
    # constraint. Note that SQLite does not enforce FOREIGN KEY references
    # unless `PRAGMA foreign_keys=ON` is set, so a *dangling* author_id would
    # not fail here.
    post = Post(title='Test', content='Content')
    db_session.session.add(post)
    with pytest.raises(IntegrityError):
        db_session.session.commit()
    db_session.session.rollback()

    user = User(username='author', email='author@example.com')
    db_session.session.add(user)
    db_session.session.commit()

    post = Post(title='Test', content='Content', author_id=user.id)
    db_session.session.add(post)
    db_session.session.commit()
    assert post.id is not None


def test_database_relationship_loading(db_session):
    """User.posts lazy-loads by default and eager-loads with joinedload."""
    from sqlalchemy import inspect
    from sqlalchemy.orm import joinedload

    user = User(username='test', email='test@example.com')
    post1 = Post(title='Post 1', content='Content 1', author=user)
    post2 = Post(title='Post 2', content='Content 2', author=user)
    db_session.session.add_all([user, post1, post2])
    db_session.session.commit()
    user_id = user.id

    # Detach everything so the next query builds a fresh object. Without
    # this, the identity map hands back `user`, whose posts are already loaded.
    db_session.session.expunge_all()
    db_user = db_session.session.get(User, user_id)
    # NOTE(review): assumes the relationship uses the default lazy='select'.
    assert 'posts' in inspect(db_user).unloaded
    assert len(db_user.posts) == 2

    db_session.session.expunge_all()
    optimized_user = User.query.options(
        joinedload(User.posts)
    ).filter_by(id=user_id).first()
    # joinedload populated posts in the same query.
    assert 'posts' not in inspect(optimized_user).unloaded
    assert len(optimized_user.posts) == 2
# tests/test_api.py
import json
import pytest


class TestUserAPI:
    """Tests for the /api/users endpoints."""

    def test_get_users(self, client, db_session):
        """GET /api/users lists every user with its public fields."""
        from app.models import User

        users = [
            User(username='user1', email='user1@example.com'),
            User(username='user2', email='user2@example.com')
        ]
        db_session.session.add_all(users)
        db_session.session.commit()

        response = client.get('/api/users')
        assert response.status_code == 200

        data = response.get_json()
        assert 'users' in data
        assert len(data['users']) == 2

        user_data = data['users'][0]
        for field in ('id', 'username', 'email', 'created_at'):
            assert field in user_data

    def test_get_user(self, client, db_session):
        """GET /api/users/<id> returns that user in a 'user' envelope."""
        from app.models import User

        user = User(username='testuser', email='test@example.com')
        db_session.session.add(user)
        db_session.session.commit()

        response = client.get(f'/api/users/{user.id}')
        assert response.status_code == 200

        data = response.get_json()
        assert data['user']['id'] == user.id
        assert data['user']['username'] == 'testuser'
        assert data['user']['email'] == 'test@example.com'

    def test_create_user(self, client, db_session):
        """POST /api/users creates a user and returns 201."""
        user_data = {
            'username': 'newuser',
            'email': 'newuser@example.com',
            'password': 'securepassword123'
        }
        response = client.post(
            '/api/users',
            data=json.dumps(user_data),
            content_type='application/json'
        )
        assert response.status_code == 201

        data = response.get_json()
        assert 'user' in data
        assert data['user']['username'] == 'newuser'
        assert data['user']['email'] == 'newuser@example.com'

        from app.models import User
        db_user = User.query.filter_by(username='newuser').first()
        assert db_user is not None
        assert db_user.email == 'newuser@example.com'

    def test_update_user(self, client, db_session, auth_token):
        """PUT /api/users/<id> with a bearer token updates the user."""
        from app.models import User

        user = User(username='oldname', email='old@example.com')
        db_session.session.add(user)
        db_session.session.commit()

        update_data = {
            'username': 'newname',
            'email': 'new@example.com'
        }
        headers = {'Authorization': f'Bearer {auth_token}'}
        response = client.put(
            f'/api/users/{user.id}',
            data=json.dumps(update_data),
            content_type='application/json',
            headers=headers
        )
        assert response.status_code == 200

        data = response.get_json()
        assert data['user']['username'] == 'newname'
        assert data['user']['email'] == 'new@example.com'

        # Session.get replaces the legacy Query.get (deprecated in SQLAlchemy 2.0).
        db_user = db_session.session.get(User, user.id)
        assert db_user.username == 'newname'

    def test_delete_user(self, client, db_session, auth_token):
        """DELETE /api/users/<id> with a bearer token removes the user."""
        from app.models import User

        user = User(username='todelete', email='delete@example.com')
        db_session.session.add(user)
        db_session.session.commit()
        user_id = user.id

        headers = {'Authorization': f'Bearer {auth_token}'}
        response = client.delete(f'/api/users/{user_id}', headers=headers)
        assert response.status_code == 204

        assert db_session.session.get(User, user_id) is None

    def test_api_authentication_required(self, client, db_session):
        """Writes without a token are rejected with a structured 401."""
        from app.models import User

        user = User(username='test', email='test@example.com')
        db_session.session.add(user)
        db_session.session.commit()

        response = client.put(
            f'/api/users/{user.id}',
            data=json.dumps({'username': 'updated'}),
            content_type='application/json'
        )
        assert response.status_code == 401

        data = response.get_json()
        assert 'error' in data
        assert 'message' in data

    def test_api_validation_errors(self, client):
        """Invalid payloads return 400 with per-field errors."""
        invalid_data = {
            'username': '',  # empty username
            'email': 'invalid-email',
            'password': '123'  # too short
        }
        response = client.post(
            '/api/users',
            data=json.dumps(invalid_data),
            content_type='application/json'
        )
        assert response.status_code == 400

        data = response.get_json()
        assert 'errors' in data
        for field in ('username', 'email', 'password'):
            assert field in data['errors']


@pytest.fixture
def auth_token(client, db_session):
    """Create 'testuser' and return an access token from /api/auth/login."""
    from app.models import User

    user = User(username='testuser', email='test@example.com')
    user.set_password('password123')
    db_session.session.add(user)
    db_session.session.commit()

    response = client.post('/api/auth/login', data=json.dumps({
        'username': 'testuser',
        'password': 'password123'
    }), content_type='application/json')
    # Fail here with a clear message rather than a KeyError inside the test.
    assert response.status_code == 200, response.get_data(as_text=True)

    data = response.get_json()
    return data['access_token']
# Run the tests and produce HTML and terminal (with missing lines) coverage reports
pytest --cov=app --cov-report=html --cov-report=term-missing
# Run only one module's tests, measuring coverage of app.models only
pytest tests/test_models.py --cov=app.models
# Fail the run if total coverage is below 80%
pytest --cov=app --cov-fail-under=80
# Produce an XML report (for CI/CD integration)
pytest --cov=app --cov-report=xml
# .coveragerc
[run]
source = app
omit =
    */tests/*
    */migrations/*
    */__pycache__/*
    */venv/*
    */env/*

[report]
# Fail the report if total coverage is below this percentage
fail_under = 80
# Regexes for lines excluded from measurement. Setting this replaces
# coverage.py's default list, so the pragma is repeated here.
exclude_lines =
    # Honour explicit pragma comments
    pragma: no cover
    # Skip bare `pass` statements
    ^\s*pass\s*$
    # Skip `raise NotImplementedError` stubs
    ^\s*raise NotImplementedError\s*$
    # Skip abstract methods
    @abstractmethod
# List the line numbers that were not covered
show_missing = True
# Sort the report by coverage percentage
sort = Cover

[html]
# Output directory for the HTML report
directory = htmlcov
# Show recorded contexts (e.g. which test executed each line). This only has
# data when contexts are recorded, e.g. with pytest-cov's --cov-context=test.
show_contexts = True
| 覆盖率 | 等级 | 建议 |
|---|---|---|
| 90%+ | 优秀 | 继续保持,关注边缘情况 |
| 80-89% | 良好 | 覆盖主要功能,可进一步提升 |
| 70-79% | 一般 | 需要增加测试,特别是关键路径 |
| <70% | 不足 | 急需增加测试,存在风险 |
# .github/workflows/test.yml
name: Tests

on:
  push:
    branches: [ main, develop ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest

    services:
      postgres:
        image: postgres:13
        env:
          POSTGRES_PASSWORD: postgres
          # Create the database DATABASE_URL points at (the image only
          # creates "postgres" by default).
          POSTGRES_DB: test_db
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 5432:5432

    strategy:
      matrix:
        # Quoted so YAML does not read 3.10 as the float 3.1
        python-version: ["3.8", "3.9", "3.10"]

    steps:
      - uses: actions/checkout@v4

      # Expressions need the leading `$`; `{{ ... }}` alone is a literal string.
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install -r requirements-test.txt

      - name: Run linting
        run: |
          pip install flake8 black isort
          flake8 app tests
          black --check app tests
          isort --check-only app tests

      - name: Run tests with pytest
        env:
          DATABASE_URL: postgresql://postgres:postgres@localhost:5432/test_db
          SECRET_KEY: ${{ secrets.TEST_SECRET_KEY }}
        run: |
          pytest --cov=app --cov-report=xml --cov-report=term-missing

      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v2
        with:
          file: ./coverage.xml
          flags: unittests
          name: codecov-umbrella

      - name: Check coverage threshold
        run: |
          coverage report --fail-under=80
# Dockerfile.test
FROM python:3.9-slim
WORKDIR /app
# Install system packages (gcc, PostgreSQL client) and clear the apt cache
RUN apt-get update && apt-get install -y \
gcc \
postgresql-client \
&& rm -rf /var/lib/apt/lists/*
# Copy the dependency lists first so this layer is cached across code changes
COPY requirements.txt requirements-test.txt ./
# Install application and test dependencies
RUN pip install --no-cache-dir -r requirements.txt -r requirements-test.txt
# Copy the application code
COPY . .
# Run the test suite with a terminal coverage report
CMD ["pytest", "-v", "--cov=app", "--cov-report=term-missing"]
# docker-compose.test.yml
# Note: Compose v2 ignores the top-level `version` key (kept for older tooling).
version: '3.8'

services:
  db:
    image: postgres:13
    environment:
      POSTGRES_USER: postgres
      POSTGRES_PASSWORD: postgres
      POSTGRES_DB: test_db
    ports:
      - "5432:5432"
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U postgres"]
      interval: 10s
      timeout: 5s
      retries: 5

  test:
    build:
      context: .
      dockerfile: Dockerfile.test
    # Wait until Postgres passes its healthcheck, not merely until it starts.
    depends_on:
      db:
        condition: service_healthy
    environment:
      DATABASE_URL: postgresql://postgres:postgres@db:5432/test_db
      SECRET_KEY: test-secret-key
    # Mount the working tree so tests run against local edits without rebuilding.
    volumes:
      - .:/app
    command: pytest -v --cov=app --cov-report=html
import pytest


# Good test names: state the scenario and the expected outcome.
def test_user_creation_success(): ...
def test_login_with_invalid_credentials(): ...
def test_get_user_by_id_returns_correct_user(): ...
def test_update_profile_with_empty_name_fails(): ...


# Group related tests in a descriptively named class.
class TestUserAuthentication:
    def test_valid_login_creates_session(self): ...
    def test_invalid_login_returns_error(self): ...


# Parametrized test. The argument is named `role` rather than `input`,
# which would shadow the builtin.
@pytest.mark.parametrize('role,expected', [
    ('admin', True),
    ('user', False),
    ('', False),
])
def test_is_admin_role(role, expected): ...
有几种策略可以隔离测试数据:
# Option 1: roll back a per-test outer transaction
@pytest.fixture
def db_session(app):
    """Run each test inside a transaction that is rolled back afterwards.

    The session is bound to one dedicated connection whose outer transaction
    is never committed. Even ``session.commit()`` calls made by the test, or
    by the code under test, only release a SAVEPOINT and are discarded at
    teardown. Yields the Flask-SQLAlchemy ``db`` object.

    NOTE(review): relies on SQLAlchemy 2.0's ``join_transaction_mode`` and
    on Model.query / db.session resolving ``db.session`` at call time. SQLite
    via pysqlite also needs SQLAlchemy's documented SAVEPOINT workaround.
    Confirm against the pinned versions.
    """
    from sqlalchemy.orm import scoped_session, sessionmaker
    from app.extensions import db

    with app.app_context():
        # Build the schema before opening the transaction that gets rolled back.
        db.create_all()
        connection = db.engine.connect()
        transaction = connection.begin()

        original_session = db.session
        # Without binding the session to `connection`, work done through
        # db.session happens on other connections and the rollback below
        # would not undo it.
        db.session = scoped_session(sessionmaker(
            bind=connection,
            join_transaction_mode='create_savepoint',
        ))
        try:
            yield db
        finally:
            db.session.remove()
            db.session = original_session
            transaction.rollback()
            connection.close()
# Option 2: a temporary on-disk SQLite database
@pytest.fixture(scope='session')
def test_db():
    """Create an app backed by a temporary SQLite file and yield it.

    The file is deleted after the session. The URI is passed to create_app()
    rather than assigned to app.config afterwards, because Flask-SQLAlchemy
    (3.x) builds its engine during init_app. A later config change would
    not rebind it.
    """
    import os
    import tempfile
    from app import create_app
    from app.extensions import db

    db_fd, db_path = tempfile.mkstemp(suffix='.db')
    # SQLite opens the file by path; the descriptor from mkstemp is not needed.
    os.close(db_fd)

    flask_app = create_app({
        'TESTING': True,
        'SQLALCHEMY_DATABASE_URI': f'sqlite:///{db_path}',
        'SECRET_KEY': 'test-secret-key',
        'WTF_CSRF_ENABLED': False,
    })
    try:
        with flask_app.app_context():
            db.create_all()
        yield flask_app
    finally:
        with flask_app.app_context():
            # Release pooled connections so the file can be removed (required on Windows).
            db.engine.dispose()
        os.unlink(db_path)
测试需要登录认证的功能,有以下几种方法:
# Option 1: a fixture that logs in through the real login view
@pytest.fixture
def authenticated_client(client, db_session):
    """Return a test client already logged in as 'testuser'."""
    from app.models import User

    user = User(username='testuser', email='test@example.com')
    user.set_password('password123')
    db_session.session.add(user)
    db_session.session.commit()

    client.post('/login', data={
        'username': 'testuser',
        'password': 'password123'
    })
    return client


# Option 2: write the login state straight into the session
@pytest.fixture
def logged_in_user(client, db_session):
    """Mark a new user as logged in by seeding the session cookie.

    Flask-Login reads the user id from the '_user_id' key (stored as a
    string), so that key is set instead of a custom 'user_id'.
    NOTE(review): with session_protection='strong', Flask-Login also checks
    '_id' and would log this user out. Use Option 1 in that case.
    """
    from app.models import User
    from flask import session

    user = User(username='testuser', email='test@example.com')
    db_session.session.add(user)
    db_session.session.commit()

    with client.session_transaction() as sess:
        sess['_user_id'] = str(user.id)
        sess['_fresh'] = True
    return user


# Option 3: a login helper function
def login_user(client, username, password):
    """Post credentials to /login and follow the redirect; return the response."""
    return client.post('/login', data={
        'username': username,
        'password': password
    }, follow_redirects=True)


# Usage in a test
def test_protected_page(authenticated_client):
    response = authenticated_client.get('/dashboard')
    assert response.status_code == 200
    assert b'Dashboard' in response.data


# Users with different roles
@pytest.mark.parametrize('role,expected_status', [
    ('admin', 200),
    ('user', 200),
    ('guest', 403),
])
def test_admin_page_access(client, db_session, role, expected_status):
    """Check /admin access per role.

    db_session provides the tables the created user is stored in.
    NOTE(review): create_user_with_role is not defined in this snippet; it is
    assumed to create a user with password 'password'. Also confirm that
    plain 'user' accounts are really meant to reach /admin (200).
    """
    user = create_user_with_role(role)
    login_user(client, user.username, 'password')

    response = client.get('/admin')
    assert response.status_code == expected_status
处理慢速测试的策略:
# Mark slow tests
import pytest


@pytest.mark.slow
def test_external_api_integration():
    """Talk to the real external API (slow; deselect with -m "not slow")."""
    result = call_external_api()
    assert result is not None

# Skip slow tests when running:
# pytest -m "not slow"


# Replace the network call with a mock
from unittest.mock import Mock, patch


def test_external_api_with_mock():
    """Exercise the API-handling logic against a mocked requests.get."""
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.json.return_value = {'data': 'test'}

    with patch('requests.get', return_value=mock_response):
        result = get_external_data()
        assert result == {'data': 'test'}


# Use a test double (a fake implementation)
class FakeExternalService:
    """In-memory stand-in for the external service."""

    def get_data(self):
        return {'data': 'fake data'}


def test_with_fake_service():
    """Demonstrates the fake. In real tests, inject it into the code under test."""
    service = FakeExternalService()
    result = service.get_data()
    assert result['data'] == 'fake data'


# Per-test timeouts. The `timeout` marker comes from the pytest-timeout
# plugin. Install it (pip install pytest-timeout), otherwise --strict-markers
# rejects the marker.
import pytest
import time


@pytest.mark.timeout(5)  # fail the test if it runs longer than 5 seconds
def test_with_timeout():
    # Stays well under the limit, so the suite passes. Any call that blocked
    # longer than 5 seconds would be failed by pytest-timeout.
    time.sleep(0.1)
测试异步Flask应用的方法:
# Install async test support
# pip install pytest-asyncio
import pytest
import asyncio
import sys
from flask import Flask

# An app with an async view (async views need the `flask[async]` extra).
app = Flask(__name__)


@app.route('/async')
async def async_endpoint():
    await asyncio.sleep(0.1)
    return 'Async Response'


def test_async_endpoint():
    """Test an async view with Flask's regular test client.

    Flask runs the coroutine to completion while handling the request, and
    the test client is synchronous. get() returns an ordinary response, so
    neither it nor get_data() may be awaited.
    """
    test_client = app.test_client()
    response = test_client.get('/async')
    assert response.status_code == 200
    assert response.get_data(as_text=True) == 'Async Response'


# Testing a plain coroutine function
async def async_function():
    await asyncio.sleep(0.1)
    return 'result'


@pytest.mark.asyncio
async def test_async_function():
    result = await async_function()
    assert result == 'result'


# Custom event-loop fixture
# NOTE(review): overriding `event_loop` is deprecated since pytest-asyncio
# 0.21. Prefer its loop-scope configuration on newer versions.
@pytest.fixture
def event_loop():
    """Provide a fresh event loop per test and close it afterwards."""
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()


# Running blocking database calls off the event loop
@pytest.mark.skipif(sys.version_info < (3, 9), reason='asyncio.to_thread requires Python 3.9+')
@pytest.mark.asyncio
async def test_async_database(db_session):
    """Run sync SQLAlchemy calls in a worker thread via asyncio.to_thread.

    to_thread copies the current contextvars, so the Flask app context pushed
    by db_session is visible in the worker. Each call is awaited before the
    next one, so the session is never used by two threads at once.
    """
    from app.models import User

    user = User(username='asyncuser', email='async@example.com')
    db_session.session.add(user)

    await asyncio.to_thread(db_session.session.commit)

    users = await asyncio.to_thread(
        lambda: User.query.filter_by(username='asyncuser').first()
    )
    assert users is not None
    assert users.username == 'asyncuser'


# Testing an async HTTP client with httpx
import httpx
import pytest


@pytest.mark.asyncio
async def test_external_api_async():
    """Async HTTP client test that does not depend on the network.

    httpx.MockTransport answers requests in-process. Drop the transport and
    mark the test `slow` to exercise the live API instead.
    """
    def handler(request):
        assert request.url == httpx.URL('https://api.example.com/data')
        return httpx.Response(200, json={'data': 'test'})

    async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
        response = await client.get('https://api.example.com/data')

    assert response.status_code == 200
    data = response.json()
    assert 'data' in data