Flask单元测试

测试的重要性:高质量的测试是保证应用稳定性的关键,可以及早发现问题,提高代码质量。

为什么需要测试?

测试在软件开发中扮演着至关重要的角色:

及早发现错误

在代码部署前发现问题,降低修复成本

代码重构安全网

修改代码时确保功能不受影响

文档作用

测试代码是最好的功能文档

测试金字塔

端到端测试 (E2E)
数量少,覆盖完整流程
集成测试
测试组件间交互
单元测试
数量最多,测试单个函数/方法

测试框架选择

框架 特点 安装命令 适用场景
unittest Python标准库,无需额外安装 内置 基础测试,简单项目
pytest 功能强大,语法简洁,插件丰富 pip install pytest 大多数Flask项目首选
nose2 unittest扩展,兼容unittest pip install nose2 从unittest迁移

pytest基础配置

1. 安装pytest和相关插件

# 安装pytest和Flask测试扩展
pip install pytest pytest-flask pytest-cov

# 用于模拟外部依赖
pip install pytest-mock

# 数据库测试
pip install pytest-flask-sqlalchemy

# 可选:增强断言信息
pip install pytest-assert-utils

2. pytest配置文件

# pytest.ini
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
    --verbose
    --tb=short
    --strict-markers

markers =
    slow: marks tests as slow (deselect with '-m "not slow"')
    integration: integration tests
    e2e: end-to-end tests

3. conftest.py - 测试配置

# tests/conftest.py
import pytest
import sys
import os

# 添加项目根目录到Python路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from app import create_app
from app.extensions import db

@pytest.fixture(scope='session')
def app():
    """创建测试应用实例"""
    # 使用测试配置
    app = create_app({
        'TESTING': True,
        'SQLALCHEMY_DATABASE_URI': 'sqlite:///:memory:',
        'SQLALCHEMY_TRACK_MODIFICATIONS': False,
        'SECRET_KEY': 'test-secret-key',
        'WTF_CSRF_ENABLED': False  # 禁用CSRF便于测试
    })

    return app

@pytest.fixture(scope='function')
def client(app):
    """创建测试客户端"""
    with app.test_client() as client:
        yield client

@pytest.fixture(scope='function')
def db_session(app):
    """创建数据库会话"""
    with app.app_context():
        db.create_all()
        yield db
        db.session.remove()
        db.drop_all()

@pytest.fixture(scope='function')
def runner(app):
    """创建CLI运行器"""
    return app.test_cli_runner()

@pytest.fixture(scope='session', autouse=True)
def setup_test_environment():
    """设置测试环境"""
    # 测试前执行
    print("\n=== 设置测试环境 ===")
    yield
    # 测试后执行
    print("\n=== 清理测试环境 ===")

单元测试基础

1. 第一个Flask测试

示例:测试首页路由
通过
# tests/test_basic.py
import pytest

def test_index_route(client):
    """测试首页是否正常返回"""
    response = client.get('/')

    # 断言状态码
    assert response.status_code == 200

    # 断言响应内容
    assert b'Welcome' in response.data or b'Home' in response.data

    # 断言内容类型
    assert 'text/html' in response.content_type

def test_404_error(client):
    """测试不存在的页面返回404"""
    response = client.get('/nonexistent-page')
    assert response.status_code == 404

def test_redirect(client):
    """测试重定向"""
    response = client.get('/old-url')

    # 断言重定向
    assert response.status_code == 302
    assert '/new-url' in response.location

    # 跟随重定向
    response = client.get('/old-url', follow_redirects=True)
    assert response.status_code == 200

2. 测试配置和环境

# tests/test_config.py
import pytest

def test_app_config(app):
    """测试应用配置"""
    assert app.config['TESTING'] == True
    assert app.config['SECRET_KEY'] == 'test-secret-key'
    assert app.config['SQLALCHEMY_DATABASE_URI'] == 'sqlite:///:memory:'
    assert app.config['WTF_CSRF_ENABLED'] == False

def test_app_debug_mode(app):
    """测试调试模式"""
    # 生产环境应该关闭调试模式
    if app.config.get('ENVIRONMENT') == 'production':
        assert app.debug == False
        assert app.config['TESTING'] == False
    else:
        # 开发环境可以开启调试
        pass

def test_app_extensions(app):
    """测试扩展是否正常加载"""
    assert hasattr(app, 'extensions')
    assert 'sqlalchemy' in app.extensions

    # 检查数据库扩展
    from app.extensions import db
    assert db.engine.url.database == ':memory:'  # 内存数据库

测试Flask视图

1. 测试GET请求

# tests/test_views.py
import json

def test_get_user_profile(client):
    """测试获取用户资料"""
    # 模拟登录
    client.post('/login', data={
        'username': 'testuser',
        'password': 'password123'
    })

    response = client.get('/profile')

    assert response.status_code == 200
    assert b'Profile' in response.data
    assert b'testuser' in response.data

def test_get_with_query_params(client):
    """测试带查询参数的GET请求"""
    response = client.get('/search?q=flask&page=1')

    assert response.status_code == 200
    assert b'flask' in response.data.lower()

    # 检查查询参数
    assert b'page 1' in response.data

def test_get_json_api(client):
    """测试返回JSON的API端点"""
    response = client.get('/api/users/1')

    assert response.status_code == 200
    assert response.content_type == 'application/json'

    data = json.loads(response.data)

    # 断言JSON结构
    assert 'id' in data
    assert 'name' in data
    assert 'email' in data

    # 断言具体值
    assert data['id'] == 1
    assert isinstance(data['name'], str)

2. 测试POST请求

示例:测试用户注册
通过
# tests/test_auth.py
def test_user_registration_success(client, db_session):
    """测试用户注册成功"""
    data = {
        'username': 'newuser',
        'email': 'newuser@example.com',
        'password': 'securepassword123',
        'confirm_password': 'securepassword123'
    }

    response = client.post('/register',
                          data=data,
                          follow_redirects=True)

    # 断言注册成功
    assert response.status_code == 200
    assert b'Registration successful' in response.data
    assert b'Welcome' in response.data

    # 检查用户是否被创建
    from app.models import User
    user = User.query.filter_by(username='newuser').first()
    assert user is not None
    assert user.email == 'newuser@example.com'

def test_user_registration_failure(client):
    """测试用户注册失败情况"""
    # 测试密码不匹配
    data = {
        'username': 'newuser',
        'email': 'newuser@example.com',
        'password': 'password123',
        'confirm_password': 'differentpassword'
    }

    response = client.post('/register', data=data)

    assert response.status_code == 200  # 应该停留在注册页面
    assert b'Passwords do not match' in response.data

    # 测试邮箱已存在
    data = {
        'username': 'anotheruser',
        'email': 'existing@example.com',  # 假设这个邮箱已存在
        'password': 'password123',
        'confirm_password': 'password123'
    }

    response = client.post('/register', data=data)
    assert b'Email already registered' in response.data

    # 测试缺少必要字段
    response = client.post('/register', data={'username': 'incomplete'})
    assert b'This field is required' in response.data

3. 测试表单验证

# tests/test_forms.py
import pytest
from app.forms import RegistrationForm, LoginForm
from wtforms.validators import ValidationError

def test_registration_form_valid():
    """测试注册表单验证通过"""
    form = RegistrationForm(data={
        'username': 'validuser',
        'email': 'valid@example.com',
        'password': 'ValidPass123',
        'confirm_password': 'ValidPass123'
    })

    assert form.validate() == True
    assert len(form.errors) == 0

def test_registration_form_invalid():
    """测试注册表单验证失败"""
    # 弱密码
    form = RegistrationForm(data={
        'username': 'test',
        'email': 'test@example.com',
        'password': '123',
        'confirm_password': '123'
    })

    assert form.validate() == False
    assert 'password' in form.errors
    assert len(form.errors['password']) > 0

    # 邮箱格式错误
    form = RegistrationForm(data={
        'username': 'test',
        'email': 'invalid-email',
        'password': 'Password123',
        'confirm_password': 'Password123'
    })

    assert form.validate() == False
    assert 'email' in form.errors

    # 密码不匹配
    form = RegistrationForm(data={
        'username': 'test',
        'email': 'test@example.com',
        'password': 'Password123',
        'confirm_password': 'Different123'
    })

    assert form.validate() == False
    assert 'confirm_password' in form.errors

@pytest.mark.parametrize('username,expected', [
    ('validname', True),
    ('a', False),  # 太短
    ('username_with_underscore', True),
    ('user@name', False),  # 包含特殊字符
    ('toolongusername12345', False),  # 太长
])
def test_username_validation(username, expected):
    """参数化测试用户名验证"""
    form = RegistrationForm(data={
        'username': username,
        'email': 'test@example.com',
        'password': 'Password123',
        'confirm_password': 'Password123'
    })

    # 只验证用户名字段
    form.validate_username(form.username)

    if expected:
        assert 'username' not in form.errors
    else:
        assert 'username' in form.errors

数据库测试

1. 测试模型

# tests/test_models.py
import pytest
from datetime import datetime, timedelta
from app.models import User, Post, db

def test_user_creation(db_session):
    """测试用户模型创建"""
    user = User(
        username='testuser',
        email='test@example.com',
        password_hash='hashed_password'
    )

    db_session.session.add(user)
    db_session.session.commit()

    # 验证用户属性
    assert user.id is not None
    assert user.username == 'testuser'
    assert user.email == 'test@example.com'
    assert user.created_at is not None
    assert user.is_active == True

    # 测试字符串表示
    assert str(user) == f'<User testuser>'

    # 测试密码验证
    user.set_password('mypassword')
    assert user.check_password('mypassword') == True
    assert user.check_password('wrongpassword') == False

def test_user_relationships(db_session):
    """测试用户关系(一对多)"""
    user = User(username='author', email='author@example.com')

    # 创建关联的帖子
    post1 = Post(title='First Post', content='Content 1', author=user)
    post2 = Post(title='Second Post', content='Content 2', author=user)

    db_session.session.add_all([user, post1, post2])
    db_session.session.commit()

    # 验证关系
    assert len(user.posts) == 2
    assert post1.author == user
    assert post2.author == user
    assert post1 in user.posts
    assert post2 in user.posts

def test_post_timestamps(db_session):
    """测试帖子的时间戳"""
    user = User(username='test', email='test@example.com')
    post = Post(title='Test Post', content='Content', author=user)

    before_creation = datetime.utcnow()
    db_session.session.add(post)
    db_session.session.commit()
    after_creation = datetime.utcnow()

    # 验证创建时间
    assert post.created_at >= before_creation
    assert post.created_at <= after_creation

    # 测试更新时间
    original_updated_at = post.updated_at

    # 等待一小段时间
    import time
    time.sleep(0.01)

    # 更新帖子
    post.content = 'Updated content'
    db_session.session.commit()

    assert post.updated_at > original_updated_at

@pytest.mark.parametrize('username,email,expected', [
    ('user1', 'user1@example.com', True),  # 有效
    ('', 'user@example.com', False),  # 空用户名
    ('user', '', False),  # 空邮箱
    ('user', 'invalid-email', False),  # 无效邮箱
])
def test_user_validation(username, email, expected, db_session):
    """测试用户模型验证"""
    user = User(username=username, email=email)

    if expected:
        db_session.session.add(user)
        db_session.session.commit()
        assert user.id is not None
    else:
        with pytest.raises(Exception):
            db_session.session.add(user)
            db_session.session.commit()
            db_session.session.rollback()

2. 测试数据库操作

示例:测试CRUD操作
通过
# tests/test_crud.py
import pytest
from app.models import User, Post, Comment
from app.services import UserService, PostService

class TestUserService:
    """测试用户服务"""

    def test_create_user(self, db_session):
        """测试创建用户"""
        service = UserService(db_session)

        user_data = {
            'username': 'newuser',
            'email': 'newuser@example.com',
            'password': 'password123'
        }

        user = service.create_user(**user_data)

        assert user.id is not None
        assert user.username == 'newuser'
        assert user.email == 'newuser@example.com'
        assert user.check_password('password123')

        # 验证用户已保存到数据库
        db_user = User.query.get(user.id)
        assert db_user == user

    def test_get_user_by_email(self, db_session):
        """测试通过邮箱获取用户"""
        # 先创建用户
        user = User(
            username='testuser',
            email='test@example.com'
        )
        user.set_password('password123')
        db_session.session.add(user)
        db_session.session.commit()

        service = UserService(db_session)
        found_user = service.get_user_by_email('test@example.com')

        assert found_user is not None
        assert found_user.id == user.id
        assert found_user.username == 'testuser'

        # 测试不存在的邮箱
        not_found = service.get_user_by_email('nonexistent@example.com')
        assert not_found is None

    def test_update_user(self, db_session):
        """测试更新用户"""
        user = User(
            username='oldname',
            email='old@example.com'
        )
        db_session.session.add(user)
        db_session.session.commit()

        service = UserService(db_session)

        # 更新用户信息
        updated_user = service.update_user(
            user.id,
            username='newname',
            email='new@example.com'
        )

        assert updated_user.username == 'newname'
        assert updated_user.email == 'new@example.com'

        # 验证数据库已更新
        db_user = User.query.get(user.id)
        assert db_user.username == 'newname'

    def test_delete_user(self, db_session):
        """测试删除用户"""
        user = User(
            username='todelete',
            email='delete@example.com'
        )
        db_session.session.add(user)
        db_session.session.commit()

        user_id = user.id

        service = UserService(db_session)
        result = service.delete_user(user_id)

        assert result == True

        # 验证用户已删除
        deleted_user = User.query.get(user_id)
        assert deleted_user is None

        # 测试删除不存在的用户
        result = service.delete_user(99999)
        assert result == False

class TestPostService:
    """测试帖子服务"""

    def test_create_post(self, db_session):
        """测试创建帖子"""
        user = User(username='author', email='author@example.com')
        db_session.session.add(user)
        db_session.session.commit()

        service = PostService(db_session)

        post_data = {
            'title': 'Test Post',
            'content': 'This is a test post.',
            'author_id': user.id
        }

        post = service.create_post(**post_data)

        assert post.id is not None
        assert post.title == 'Test Post'
        assert post.content == 'This is a test post.'
        assert post.author_id == user.id
        assert post.is_published == True  # 默认发布

    def test_get_published_posts(self, db_session):
        """测试获取已发布的帖子"""
        user = User(username='author', email='author@example.com')

        # 创建多个帖子
        posts = [
            Post(title='Published 1', content='Content 1', author=user, is_published=True),
            Post(title='Published 2', content='Content 2', author=user, is_published=True),
            Post(title='Draft', content='Draft content', author=user, is_published=False),
        ]

        db_session.session.add_all([user] + posts)
        db_session.session.commit()

        service = PostService(db_session)
        published_posts = service.get_published_posts()

        assert len(published_posts) == 2
        assert all(post.is_published for post in published_posts)

        # 测试分页
        paginated = service.get_published_posts(page=1, per_page=1)
        assert len(paginated.items) == 1
        assert paginated.total == 2
        assert paginated.pages == 2

    def test_search_posts(self, db_session):
        """测试搜索帖子"""
        user = User(username='author', email='author@example.com')

        posts = [
            Post(title='Python Tutorial', content='Learn Python', author=user),
            Post(title='Flask Guide', content='Flask web framework', author=user),
            Post(title='Django vs Flask', content='Comparison', author=user),
        ]

        db_session.session.add_all([user] + posts)
        db_session.session.commit()

        service = PostService(db_session)

        # 搜索包含'Flask'的帖子
        flask_posts = service.search_posts('Flask')
        assert len(flask_posts) == 2  # 标题和内容中都包含

        # 搜索不存在的关键词
        no_results = service.search_posts('Nonexistent')
        assert len(no_results) == 0

        # 测试空搜索返回所有帖子
        all_posts = service.search_posts('')
        assert len(all_posts) == 3

模拟对象(Mocking)

1. 模拟外部API调用

# tests/test_external_api.py
import pytest
from unittest.mock import Mock, patch, MagicMock
import requests

def test_external_api_success(mocker):
    """测试外部API调用成功"""
    # 模拟requests.get返回成功响应
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        'data': {
            'id': 1,
            'name': 'John Doe',
            'email': 'john@example.com'
        }
    }

    # 使用patch替换requests.get
    with patch('requests.get', return_value=mock_response) as mock_get:
        from app.services import ExternalAPIService
        service = ExternalAPIService()

        user_data = service.get_user_data(1)

        # 验证API被调用
        mock_get.assert_called_once_with('https://api.example.com/users/1')

        # 验证返回数据
        assert user_data['id'] == 1
        assert user_data['name'] == 'John Doe'
        assert user_data['email'] == 'john@example.com'

def test_external_api_failure(mocker):
    """测试外部API调用失败"""
    # 模拟requests.get抛出异常
    with patch('requests.get', side_effect=requests.exceptions.RequestException('Network error')):
        from app.services import ExternalAPIService
        service = ExternalAPIService()

        # 应该处理异常并返回默认值
        user_data = service.get_user_data(1)

        assert user_data is None

def test_external_api_rate_limit(mocker):
    """测试API速率限制"""
    # 模拟速率限制响应
    mock_response = Mock()
    mock_response.status_code = 429
    mock_response.headers = {'Retry-After': '60'}

    with patch('requests.get', return_value=mock_response):
        from app.services import ExternalAPIService
        service = ExternalAPIService()

        # 应该处理速率限制
        result = service.call_with_retry()

        # 根据实现逻辑断言
        assert result.get('retry_after') == 60

@pytest.mark.parametrize('status_code,expected_result', [
    (200, 'success'),
    (404, 'not_found'),
    (500, 'error'),
    (401, 'unauthorized'),
])
def test_api_status_codes(status_code, expected_result, mocker):
    """参数化测试不同的API状态码"""
    mock_response = Mock()
    mock_response.status_code = status_code

    with patch('requests.get', return_value=mock_response):
        from app.services import ExternalAPIService
        service = ExternalAPIService()

        result = service.handle_response(mock_response)

        assert result == expected_result

2. 模拟文件系统操作

示例:测试文件上传服务
需要修复
# tests/test_file_service.py
import pytest
from unittest.mock import Mock, patch, mock_open
import os
import tempfile
import shutil

class TestFileService:

    def test_save_file_success(self, mocker):
        """测试文件保存成功"""
        # 模拟文件对象
        mock_file = Mock()
        mock_file.filename = 'test.txt'
        mock_file.read.return_value = b'file content'

        # 模拟os.path.exists返回False(目录不存在)
        mocker.patch('os.path.exists', return_value=False)

        # 模拟os.makedirs
        mock_makedirs = mocker.patch('os.makedirs')

        # 模拟open函数
        mock_file_obj = mock_open()

        with patch('builtins.open', mock_file_obj):
            from app.services import FileService
            service = FileService()

            file_path = service.save_file(mock_file, '/uploads')

            # 验证目录被创建
            mock_makedirs.assert_called_once_with('/uploads', exist_ok=True)

            # 验证文件被写入
            mock_file_obj.assert_called_once_with('/uploads/test.txt', 'wb')

            # 验证写入内容
            handle = mock_file_obj()
            handle.write.assert_called_once_with(b'file content')

            # 验证返回路径
            assert file_path == '/uploads/test.txt'

    def test_save_file_io_error(self, mocker):
        """测试文件保存IO错误"""
        mock_file = Mock()
        mock_file.filename = 'test.txt'
        mock_file.read.return_value = b'file content'

        # 模拟open抛出IOError
        with patch('builtins.open', side_effect=IOError('Disk full')):
            from app.services import FileService
            service = FileService()

            with pytest.raises(IOError, match='Disk full'):
                service.save_file(mock_file, '/uploads')

    def test_delete_file_success(self, mocker):
        """测试文件删除成功"""
        # 模拟os.path.exists返回True(文件存在)
        mocker.patch('os.path.exists', return_value=True)

        # 模拟os.remove
        mock_remove = mocker.patch('os.remove')

        from app.services import FileService
        service = FileService()

        result = service.delete_file('/uploads/test.txt')

        # 验证os.remove被调用
        mock_remove.assert_called_once_with('/uploads/test.txt')

        # 验证返回结果
        assert result == True

    def test_delete_file_not_found(self, mocker):
        """测试删除不存在的文件"""
        # 模拟os.path.exists返回False
        mocker.patch('os.path.exists', return_value=False)

        from app.services import FileService
        service = FileService()

        result = service.delete_file('/uploads/nonexistent.txt')

        # 应该返回False而不是抛出异常
        assert result == False

    def test_get_file_size(self, mocker):
        """测试获取文件大小"""
        # 模拟os.path.getsize
        mocker.patch('os.path.getsize', return_value=1024)

        from app.services import FileService
        service = FileService()

        size = service.get_file_size('/uploads/test.txt')

        assert size == 1024

        # 测试文件不存在的情况
        mocker.patch('os.path.getsize', side_effect=FileNotFoundError)

        size = service.get_file_size('/uploads/nonexistent.txt')

        assert size == 0

测试Flask扩展

1. 测试Flask-Login

# tests/test_auth_integration.py
import pytest
from flask import session, url_for
from flask_login import current_user

def test_login_success(client, db_session):
    """测试登录成功"""
    from app.models import User

    # 创建测试用户
    user = User(
        username='testuser',
        email='test@example.com'
    )
    user.set_password('password123')
    db_session.session.add(user)
    db_session.session.commit()

    # 登录请求
    response = client.post('/login', data={
        'username': 'testuser',
        'password': 'password123'
    }, follow_redirects=True)

    assert response.status_code == 200

    # 验证登录状态
    with client.session_transaction() as sess:
        assert 'user_id' in sess
        assert sess['user_id'] == str(user.id)

    # 验证重定向到正确页面
    assert b'Dashboard' in response.data
    assert b'Welcome testuser' in response.data

def test_login_failure(client, db_session):
    """测试登录失败"""
    from app.models import User

    user = User(
        username='testuser',
        email='test@example.com'
    )
    user.set_password('password123')
    db_session.session.add(user)
    db_session.session.commit()

    # 错误密码
    response = client.post('/login', data={
        'username': 'testuser',
        'password': 'wrongpassword'
    })

    assert response.status_code == 200
    assert b'Invalid username or password' in response.data

    # 不存在的用户
    response = client.post('/login', data={
        'username': 'nonexistent',
        'password': 'password123'
    })

    assert b'Invalid username or password' in response.data

def test_logout(client, auth_user):
    """测试注销"""
    # 先登录
    client.post('/login', data={
        'username': 'testuser',
        'password': 'password123'
    })

    # 验证登录状态
    response = client.get('/profile')
    assert response.status_code == 200

    # 注销
    response = client.get('/logout', follow_redirects=True)

    assert response.status_code == 200
    assert b'You have been logged out' in response.data

    # 验证无法访问需要登录的页面
    response = client.get('/profile')
    assert response.status_code != 200  # 应该是重定向或403

def test_protected_route_access(client, auth_user):
    """测试受保护路由的访问"""
    # 未登录时访问受保护页面
    response = client.get('/dashboard', follow_redirects=False)

    # 应该重定向到登录页面
    assert response.status_code == 302
    assert '/login' in response.location

    # 登录后访问
    client.post('/login', data={
        'username': 'testuser',
        'password': 'password123'
    })

    response = client.get('/dashboard')
    assert response.status_code == 200
    assert b'Dashboard' in response.data

@pytest.fixture
def auth_user(db_session):
    """创建认证用户fixture"""
    from app.models import User

    user = User(
        username='testuser',
        email='test@example.com'
    )
    user.set_password('password123')
    db_session.session.add(user)
    db_session.session.commit()

    return user

2. 测试Flask-SQLAlchemy

# tests/test_sqlalchemy_integration.py
import pytest
from sqlalchemy.exc import IntegrityError
from app.models import User, Post, db

def test_database_transaction_rollback(db_session):
    """测试数据库事务回滚"""
    user = User(username='testuser', email='test@example.com')
    db_session.session.add(user)

    # 验证用户未提交
    users = User.query.all()
    assert len(users) == 0

    # 提交
    db_session.session.commit()

    users = User.query.all()
    assert len(users) == 1

    # 测试回滚
    new_user = User(username='another', email='another@example.com')
    db_session.session.add(new_user)
    db_session.session.rollback()

    users = User.query.all()
    assert len(users) == 1  # 只有第一个用户

def test_unique_constraint(db_session):
    """测试唯一性约束"""
    user1 = User(username='uniqueuser', email='unique@example.com')
    db_session.session.add(user1)
    db_session.session.commit()

    # 尝试创建相同用户名的用户
    user2 = User(username='uniqueuser', email='different@example.com')
    db_session.session.add(user2)

    # 应该抛出完整性错误
    with pytest.raises(IntegrityError):
        db_session.session.commit()

    # 回滚以清理错误状态
    db_session.session.rollback()

def test_foreign_key_constraint(db_session):
    """测试外键约束"""
    # 尝试创建没有用户的帖子
    post = Post(title='Test', content='Content')
    db_session.session.add(post)

    # 应该抛出完整性错误
    with pytest.raises(IntegrityError):
        db_session.session.commit()

    db_session.session.rollback()

    # 创建用户后应该成功
    user = User(username='author', email='author@example.com')
    db_session.session.add(user)
    db_session.session.commit()

    post = Post(title='Test', content='Content', author_id=user.id)
    db_session.session.add(post)
    db_session.session.commit()

    assert post.id is not None

def test_database_relationship_loading(db_session):
    """测试关系加载策略"""
    user = User(username='test', email='test@example.com')
    post1 = Post(title='Post 1', content='Content 1', author=user)
    post2 = Post(title='Post 2', content='Content 2', author=user)

    db_session.session.add_all([user, post1, post2])
    db_session.session.commit()

    # 重新查询用户,测试延迟加载
    db_user = User.query.get(user.id)

    # posts属性应该可访问
    assert len(db_user.posts) == 2

    # 测试查询优化(joinedload)
    from sqlalchemy.orm import joinedload

    optimized_user = User.query.options(
        joinedload(User.posts)
    ).filter_by(id=user.id).first()

    # 应该已经加载了posts
    assert len(optimized_user.posts) == 2

测试API端点

1. RESTful API测试

示例:测试完整的CRUD API
通过
# tests/test_api.py
import json
import pytest

class TestUserAPI:
    """测试用户API"""

    def test_get_users(self, client, db_session):
        """测试获取用户列表"""
        from app.models import User

        # 创建测试用户
        users = [
            User(username='user1', email='user1@example.com'),
            User(username='user2', email='user2@example.com')
        ]
        db_session.session.add_all(users)
        db_session.session.commit()

        response = client.get('/api/users')

        assert response.status_code == 200
        data = json.loads(response.data)

        assert 'users' in data
        assert len(data['users']) == 2

        # 验证用户数据
        user_data = data['users'][0]
        assert 'id' in user_data
        assert 'username' in user_data
        assert 'email' in user_data
        assert 'created_at' in user_data

    def test_get_user(self, client, db_session):
        """测试获取单个用户"""
        from app.models import User

        user = User(username='testuser', email='test@example.com')
        db_session.session.add(user)
        db_session.session.commit()

        response = client.get(f'/api/users/{user.id}')

        assert response.status_code == 200
        data = json.loads(response.data)

        assert data['user']['id'] == user.id
        assert data['user']['username'] == 'testuser'
        assert data['user']['email'] == 'test@example.com'

    def test_create_user(self, client, db_session):
        """测试创建用户"""
        user_data = {
            'username': 'newuser',
            'email': 'newuser@example.com',
            'password': 'securepassword123'
        }

        response = client.post(
            '/api/users',
            data=json.dumps(user_data),
            content_type='application/json'
        )

        assert response.status_code == 201
        data = json.loads(response.data)

        assert 'user' in data
        assert data['user']['username'] == 'newuser'
        assert data['user']['email'] == 'newuser@example.com'

        # 验证用户已创建
        from app.models import User
        db_user = User.query.filter_by(username='newuser').first()
        assert db_user is not None
        assert db_user.email == 'newuser@example.com'

    def test_update_user(self, client, db_session, auth_token):
        """测试更新用户"""
        from app.models import User

        user = User(username='oldname', email='old@example.com')
        db_session.session.add(user)
        db_session.session.commit()

        update_data = {
            'username': 'newname',
            'email': 'new@example.com'
        }

        headers = {'Authorization': f'Bearer {auth_token}'}

        response = client.put(
            f'/api/users/{user.id}',
            data=json.dumps(update_data),
            content_type='application/json',
            headers=headers
        )

        assert response.status_code == 200
        data = json.loads(response.data)

        assert data['user']['username'] == 'newname'
        assert data['user']['email'] == 'new@example.com'

        # 验证数据库已更新
        db_user = User.query.get(user.id)
        assert db_user.username == 'newname'

    def test_delete_user(self, client, db_session, auth_token):
        """测试删除用户"""
        from app.models import User

        user = User(username='todelete', email='delete@example.com')
        db_session.session.add(user)
        db_session.session.commit()

        user_id = user.id

        headers = {'Authorization': f'Bearer {auth_token}'}

        response = client.delete(
            f'/api/users/{user_id}',
            headers=headers
        )

        assert response.status_code == 204

        # 验证用户已删除
        db_user = User.query.get(user_id)
        assert db_user is None

    def test_api_authentication_required(self, client, db_session):
        """测试API认证要求"""
        from app.models import User

        user = User(username='test', email='test@example.com')
        db_session.session.add(user)
        db_session.session.commit()

        # 未提供认证令牌
        response = client.put(
            f'/api/users/{user.id}',
            data=json.dumps({'username': 'updated'}),
            content_type='application/json'
        )

        assert response.status_code == 401
        data = json.loads(response.data)
        assert 'error' in data
        assert 'message' in data

    def test_api_validation_errors(self, client):
        """测试API验证错误"""
        # 无效数据
        invalid_data = {
            'username': '',  # 空用户名
            'email': 'invalid-email',
            'password': '123'  # 太短
        }

        response = client.post(
            '/api/users',
            data=json.dumps(invalid_data),
            content_type='application/json'
        )

        assert response.status_code == 400
        data = json.loads(response.data)

        assert 'errors' in data
        assert 'username' in data['errors']
        assert 'email' in data['errors']
        assert 'password' in data['errors']

@pytest.fixture
def auth_token(client, db_session):
    """获取认证令牌fixture"""
    from app.models import User

    # 创建用户
    user = User(username='testuser', email='test@example.com')
    user.set_password('password123')
    db_session.session.add(user)
    db_session.session.commit()

    # 获取令牌
    response = client.post('/api/auth/login', data=json.dumps({
        'username': 'testuser',
        'password': 'password123'
    }), content_type='application/json')

    data = json.loads(response.data)
    return data['access_token']

测试覆盖率

1. 生成测试覆盖率报告

# 运行测试并生成覆盖率报告
pytest --cov=app --cov-report=html --cov-report=term-missing

# 只运行特定模块的测试
pytest tests/test_models.py --cov=app.models

# 设置覆盖率阈值
pytest --cov=app --cov-fail-under=80

# 生成XML报告(用于CI/CD集成)
pytest --cov=app --cov-report=xml

2. 覆盖率配置文件

# .coveragerc
[run]
source = app
omit =
    */tests/*
    */migrations/*
    */__pycache__/*
    */venv/*
    */env/*

[report]
# 覆盖率阈值
fail_under = 80

# 排除行
exclude_lines =
    # 不要检查pragma注释
    pragma: no cover
    # 不要检查pass语句
    ^\s*pass\s*$
    # 不要检查raise NotImplementedError
    ^\s*raise NotImplementedError\s*$
    # 不要检查抽象方法
    @abstractmethod

# 显示哪些行未覆盖
show_missing = True

# 排序方式
sort = Cover

[html]
# HTML报告目录
directory = htmlcov

# 显示哪些分支未覆盖
show_contexts = True

3. 解读覆盖率报告

覆盖率 等级 建议
90%+ 优秀 继续保持,关注边缘情况
80-89% 良好 覆盖主要功能,可进一步提升
70-79% 一般 需要增加测试,特别是关键路径
<70% 不足 急需增加测试,存在风险

集成测试与CI/CD

1. GitHub Actions配置

# .github/workflows/test.yml
name: Tests

on:
  push:
    branches: [ main, develop ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest

    services:
      postgres:
        image: postgres:13
        env:
          POSTGRES_PASSWORD: postgres
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
        ports:
          - 5432:5432

    strategy:
      matrix:
        python-version: [3.8, 3.9, "3.10"]

    steps:
    - uses: actions/checkout@v2

    - name: Set up Python {{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: {{ matrix.python-version }}

    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install -r requirements.txt
        pip install -r requirements-test.txt

    - name: Run linting
      run: |
        pip install flake8 black isort
        flake8 app tests
        black --check app tests
        isort --check-only app tests

    - name: Run tests with pytest
      env:
        DATABASE_URL: postgresql://postgres:postgres@localhost:5432/test_db
        SECRET_KEY: ${{ secrets.TEST_SECRET_KEY }}
      run: |
        pytest --cov=app --cov-report=xml --cov-report=term-missing

    - name: Upload coverage to Codecov
      uses: codecov/codecov-action@v2
      with:
        file: ./coverage.xml
        flags: unittests
        name: codecov-umbrella

    - name: Check coverage threshold
      run: |
        coverage report --fail-under=80

2. Docker测试配置

# Dockerfile.test
FROM python:3.9-slim

WORKDIR /app

# 安装系统依赖
RUN apt-get update && apt-get install -y \
    gcc \
    postgresql-client \
    && rm -rf /var/lib/apt/lists/*

# 复制依赖文件
COPY requirements.txt requirements-test.txt ./

# 安装Python依赖
RUN pip install --no-cache-dir -r requirements.txt -r requirements-test.txt

# 复制应用代码
COPY . .

# 运行测试
CMD ["pytest", "-v", "--cov=app", "--cov-report=term-missing"]
# docker-compose.test.yml
version: '3.8'

services:
  db:
    image: postgres:13
    environment:
      POSTGRES_USER: postgres
      POSTGRES_PASSWORD: postgres
      POSTGRES_DB: test_db
    ports:
      - "5432:5432"
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U postgres"]
      interval: 10s
      timeout: 5s
      retries: 5

  test:
    build:
      context: .
      dockerfile: Dockerfile.test
    depends_on:
      db:
        condition: service_healthy
    environment:
      DATABASE_URL: postgresql://postgres:postgres@db:5432/test_db
      SECRET_KEY: test-secret-key
    volumes:
      - .:/app
    command: pytest -v --cov=app --cov-report=html

测试最佳实践

DOs - 应该做的
  • 测试名称清晰:使用描述性的测试名称
  • 独立测试:每个测试应该独立运行
  • 快速执行:测试应该快速完成
  • 覆盖边界情况:测试异常和边界条件
  • 使用fixture:重用测试设置代码
  • 参数化测试:测试多组输入数据
DON'Ts - 避免做的
  • 不要测试实现细节:测试行为,而不是实现
  • 避免测试私有方法:通过公共接口测试
  • 不要依赖测试顺序:测试应该可以任何顺序运行
  • 避免过度的模拟:只在必要时使用mock
  • 不要忽略失败测试:及时修复失败的测试
  • 避免慢速测试:不要测试外部服务(除非是集成测试)

测试命名约定

# 好的测试名称
def test_user_creation_success(): ...
def test_login_with_invalid_credentials(): ...
def test_get_user_by_id_returns_correct_user(): ...
def test_update_profile_with_empty_name_fails(): ...

# 使用描述性的测试类名
class TestUserAuthentication:
    def test_valid_login_creates_session(self): ...
    def test_invalid_login_returns_error(self): ...

# 参数化测试
@pytest.mark.parametrize('input,expected', [
    ('admin', True),
    ('user', False),
    ('', False),
])
def test_is_admin_role(input, expected): ...

常见问题

有几种策略可以隔离测试数据:

  1. 内存数据库:使用SQLite内存数据库
  2. 测试数据库:创建专门的测试数据库
  3. 事务回滚:在每个测试后回滚事务
  4. 数据库fixture:使用pytest fixture管理数据库状态
# 使用事务回滚
@pytest.fixture
def db_session(app):
    """在每个测试后回滚数据库"""
    with app.app_context():
        connection = db.engine.connect()
        transaction = connection.begin()

        # 创建所有表
        db.create_all()

        yield db

        # 测试后回滚
        transaction.rollback()
        connection.close()
        db.session.remove()

# 使用临时数据库
@pytest.fixture(scope='session')
def test_db():
    """创建临时测试数据库"""
    import tempfile
    import os

    # 创建临时数据库文件
    db_fd, db_path = tempfile.mkstemp(suffix='.db')

    # 使用临时数据库路径
    app.config['SQLALCHEMY_DATABASE_URI'] = f'sqlite:///{db_path}'

    yield

    # 清理临时文件
    os.close(db_fd)
    os.unlink(db_path)

几种测试需要认证功能的方法:

# 方法1:使用fixture创建已登录用户
@pytest.fixture
def authenticated_client(client, db_session):
    """创建已认证的测试客户端"""
    from app.models import User

    # 创建用户
    user = User(username='testuser', email='test@example.com')
    user.set_password('password123')
    db_session.session.add(user)
    db_session.session.commit()

    # 登录
    client.post('/login', data={
        'username': 'testuser',
        'password': 'password123'
    })

    return client

# 方法2:直接设置session
@pytest.fixture
def logged_in_user(client, db_session):
    """通过设置session模拟登录状态"""
    from app.models import User
    from flask import session

    user = User(username='testuser', email='test@example.com')
    db_session.session.add(user)
    db_session.session.commit()

    with client.session_transaction() as sess:
        sess['user_id'] = user.id

    return user

# 方法3:使用装饰器
def login_user(client, username, password):
    """登录用户的辅助函数"""
    return client.post('/login', data={
        'username': username,
        'password': password
    }, follow_redirects=True)

# 在测试中使用
def test_protected_page(authenticated_client):
    response = authenticated_client.get('/dashboard')
    assert response.status_code == 200
    assert b'Dashboard' in response.data

# 测试不同权限的用户
@pytest.mark.parametrize('role,expected_status', [
    ('admin', 200),
    ('user', 200),
    ('guest', 403),
])
def test_admin_page_access(client, role, expected_status):
    # 创建不同角色的用户并登录
    user = create_user_with_role(role)
    login_user(client, user.username, 'password')

    response = client.get('/admin')
    assert response.status_code == expected_status

处理慢速测试的策略:

  1. 模拟外部调用:使用unittest.mock
  2. 标记慢速测试:使用pytest标记
  3. 使用测试双倍:Stub、Mock、Fake
  4. 异步执行:使用异步测试
# 标记慢速测试
import pytest

@pytest.mark.slow
def test_external_api_integration():
    """与真实外部API交互的测试"""
    # 实际调用外部API
    result = call_external_api()
    assert result is not None

# 运行测试时排除慢速测试
# pytest -m "not slow"

# 使用模拟
from unittest.mock import Mock, patch

def test_external_api_with_mock():
    """使用模拟测试外部API逻辑"""
    mock_response = Mock()
    mock_response.status_code = 200
    mock_response.json.return_value = {'data': 'test'}

    with patch('requests.get', return_value=mock_response):
        result = get_external_data()
        assert result == {'data': 'test'}

# 使用测试双倍(Test Double)
class FakeExternalService:
    """外部服务的假实现"""
    def get_data(self):
        return {'data': 'fake data'}

def test_with_fake_service():
    """使用假服务测试"""
    service = FakeExternalService()
    result = service.get_data()
    assert result['data'] == 'fake data'

# 设置超时
import pytest
import time

@pytest.mark.timeout(5)  # 5秒超时
def test_with_timeout():
    # 如果测试超过5秒,自动失败
    time.sleep(10)  # 这会触发超时失败

测试异步Flask应用的方法:

# 安装异步测试支持
# pip install pytest-asyncio

import pytest
import asyncio
from flask import Flask

# 创建异步Flask应用
app = Flask(__name__)

@app.route('/async')
async def async_endpoint():
    await asyncio.sleep(0.1)
    return 'Async Response'

# 测试异步端点
@pytest.mark.asyncio
async def test_async_endpoint():
    """测试异步端点"""
    test_client = app.test_client()

    # 使用async/await
    response = await test_client.get('/async')

    assert response.status_code == 200
    assert await response.get_data(as_text=True) == 'Async Response'

# 测试协程函数
async def async_function():
    await asyncio.sleep(0.1)
    return 'result'

@pytest.mark.asyncio
async def test_async_function():
    result = await async_function()
    assert result == 'result'

# 使用事件循环fixture
@pytest.fixture
def event_loop():
    """创建事件循环fixture"""
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()

# 测试异步数据库操作
@pytest.mark.asyncio
async def test_async_database(db_session):
    """测试异步数据库操作"""
    from app.models import User

    # 异步创建用户
    user = User(username='asyncuser', email='async@example.com')
    db_session.session.add(user)

    # 异步提交
    await asyncio.to_thread(db_session.session.commit)

    # 异步查询
    users = await asyncio.to_thread(
        lambda: User.query.filter_by(username='asyncuser').first()
    )

    assert users is not None
    assert users.username == 'asyncuser'

# 使用httpx测试异步HTTP客户端
import httpx
import pytest

@pytest.mark.asyncio
async def test_external_api_async():
    """异步测试外部API"""
    async with httpx.AsyncClient() as client:
        response = await client.get('https://api.example.com/data')

        assert response.status_code == 200
        data = response.json()
        assert 'data' in data

测试驱动开发(TDD)流程

  1. 红:编写一个失败的测试
  2. 绿:编写最少代码使测试通过
  3. 重构:改进代码结构,保持测试通过
  4. 重复:为下一个功能重复此过程

测试资源与工具

推荐阅读
  • 《Python Testing with pytest》
  • 《Test-Driven Development with Python》
  • 《The Art of Unit Testing》
实用工具
  • pytest-cov:测试覆盖率
  • pytest-mock:模拟对象
  • factory-boy:测试数据工厂
  • faker:生成假数据