Flask用户授权:使用装饰器保护视图函数

授权是指控制用户能访问哪些资源,而装饰器是Python中实现授权控制的优雅方式。

1. 授权基础概念

认证 vs 授权
认证 (Authentication) 授权 (Authorization)
验证用户身份 验证用户权限
回答"你是谁?" 回答"你能做什么?"
登录/注册 访问控制
权限模型
  • 基于角色 (RBAC): 给角色分配权限
  • 基于权限 (PBAC): 直接给用户分配权限
  • 基于属性 (ABAC): 根据属性动态决策

2. 装饰器基础

2.1 Python装饰器回顾

# 基本装饰器示例
def simple_decorator(func):
    """一个简单的装饰器"""
    def wrapper(*args, **kwargs):
        print("函数执行前")
        result = func(*args, **kwargs)
        print("函数执行后")
        return result
    return wrapper

@simple_decorator
def hello():
    print("Hello, World!")

hello()
# 输出:
# 函数执行前
# Hello, World!
# 函数执行后

# 带参数的装饰器
def repeat(times):
    """重复执行函数的装饰器"""
    def decorator(func):
        def wrapper(*args, **kwargs):
            for _ in range(times):
                result = func(*args, **kwargs)
            return result
        return wrapper
    return decorator

@repeat(3)
def say_hello(name):
    print(f"Hello, {name}!")

say_hello("Alice")
# 输出:
# Hello, Alice!
# Hello, Alice!
# Hello, Alice!

2.2 保留函数元信息

from functools import wraps

def preserve_metadata(func):
    """保留函数元信息的装饰器"""
    @wraps(func)  # 使用wraps保留原函数的元信息
    def wrapper(*args, **kwargs):
        """包装函数"""
        return func(*args, **kwargs)
    return wrapper

@preserve_metadata
def example():
    """示例函数"""
    pass

print(example.__name__)  # 输出: example
print(example.__doc__)   # 输出: 示例函数

3. 基本授权装饰器

3.1 登录要求装饰器

from flask import flash, redirect, url_for, request
from flask_login import current_user
from functools import wraps

def login_required(func):
    """要求用户登录的装饰器"""
    @wraps(func)
    def decorated_view(*args, **kwargs):
        if not current_user.is_authenticated:
            flash('请先登录以访问此页面', 'warning')

            # 记录用户原本要访问的页面
            next_url = request.url

            # 重定向到登录页面,并传递next参数
            return redirect(url_for('auth.login', next=next_url))

        return func(*args, **kwargs)
    return decorated_view

# 使用示例
@app.route('/dashboard')
@login_required
def dashboard():
    """用户仪表板(需要登录)"""
    return render_template('dashboard.html')

@app.route('/profile')
@login_required
def profile():
    """个人资料页面(需要登录)"""
    return render_template('profile.html')

3.2 管理员权限装饰器

def admin_required(func):
    """要求管理员权限的装饰器"""
    @wraps(func)
    def decorated_view(*args, **kwargs):
        if not current_user.is_authenticated:
            flash('请先登录', 'warning')
            return redirect(url_for('auth.login', next=request.url))

        if not current_user.is_admin:
            flash('您没有权限访问此页面', 'danger')
            return redirect(url_for('main.index'))

        return func(*args, **kwargs)
    return decorated_view

# 使用示例
@app.route('/admin/dashboard')
@login_required
@admin_required  # 可以组合使用多个装饰器
def admin_dashboard():
    """管理员仪表板"""
    return render_template('admin/dashboard.html')

@app.route('/admin/users')
@admin_required
def manage_users():
    """管理用户(需要管理员权限)"""
    return render_template('admin/users.html')

4. 基于角色的授权系统

4.1 角色模型设计

from datetime import datetime

# 角色模型
class Role(db.Model):
    """角色模型"""
    __tablename__ = 'roles'

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(64), unique=True, nullable=False)
    description = db.Column(db.Text)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)

    # 关系:角色和用户(多对多)
    users = db.relationship('User', secondary='user_roles', backref='roles', lazy='dynamic')
    # 关系:角色和权限(多对多)
    permissions = db.relationship('Permission', secondary='role_permissions', backref='roles', lazy='dynamic')

    def __repr__(self):
        return f'<Role {self.name}>'

    @staticmethod
    def insert_roles():
        """插入默认角色"""
        roles = {
            'user': '普通用户',
            'author': '作者',
            'editor': '编辑',
            'moderator': '版主',
            'admin': '管理员'
        }

        for name, description in roles.items():
            role = Role.query.filter_by(name=name).first()
            if role is None:
                role = Role(name=name, description=description)
                db.session.add(role)
        db.session.commit()

# 权限模型
class Permission(db.Model):
    """权限模型"""
    __tablename__ = 'permissions'

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(64), unique=True, nullable=False)
    code = db.Column(db.String(64), unique=True, nullable=False)  # 权限代码
    description = db.Column(db.Text)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)

    def __repr__(self):
        return f'<Permission {self.name}>'

# 用户-角色关联表
user_roles = db.Table('user_roles',
    db.Column('user_id', db.Integer, db.ForeignKey('users.id'), primary_key=True),
    db.Column('role_id', db.Integer, db.ForeignKey('roles.id'), primary_key=True),
    db.Column('assigned_at', db.DateTime, default=datetime.utcnow)
)

# 角色-权限关联表
role_permissions = db.Table('role_permissions',
    db.Column('role_id', db.Integer, db.ForeignKey('roles.id'), primary_key=True),
    db.Column('permission_id', db.Integer, db.ForeignKey('permissions.id'), primary_key=True),
    db.Column('assigned_at', db.DateTime, default=datetime.utcnow)
)

# 扩展用户模型
class User(UserMixin, db.Model):
    """用户模型(扩展)"""
    __tablename__ = 'users'

    # ... 其他字段

    # 检查权限的方法
    def has_role(self, role_name):
        """检查用户是否有某个角色"""
        return self.roles.filter_by(name=role_name).first() is not None

    def has_permission(self, permission_code):
        """检查用户是否有某个权限"""
        # 遍历用户的所有角色,检查是否有该权限
        for role in self.roles:
            if role.permissions.filter_by(code=permission_code).first():
                return True
        return False

    def add_role(self, role_name):
        """给用户添加角色"""
        role = Role.query.filter_by(name=role_name).first()
        if role and role not in self.roles:
            self.roles.append(role)
            db.session.commit()

    def remove_role(self, role_name):
        """移除用户的角色"""
        role = Role.query.filter_by(name=role_name).first()
        if role and role in self.roles:
            self.roles.remove(role)
            db.session.commit()

4.2 权限装饰器

def permission_required(permission_code):
    """要求特定权限的装饰器"""
    def decorator(func):
        @wraps(func)
        def decorated_view(*args, **kwargs):
            if not current_user.is_authenticated:
                flash('请先登录', 'warning')
                return redirect(url_for('auth.login', next=request.url))

            if not current_user.has_permission(permission_code):
                flash('您没有执行此操作的权限', 'danger')
                return redirect(url_for('main.index'))

            return func(*args, **kwargs)
        return decorated_view
    return decorator

def role_required(role_name):
    """要求特定角色的装饰器"""
    def decorator(func):
        @wraps(func)
        def decorated_view(*args, **kwargs):
            if not current_user.is_authenticated:
                flash('请先登录', 'warning')
                return redirect(url_for('auth.login', next=request.url))

            if not current_user.has_role(role_name):
                flash(f'需要 {role_name} 角色才能访问此页面', 'danger')
                return redirect(url_for('main.index'))

            return func(*args, **kwargs)
        return decorated_view
    return decorator

# 使用示例
@app.route('/article/create')
@login_required
@role_required('author')  # 需要作者角色
def create_article():
    """创建文章(需要作者角色)"""
    return render_template('article/create.html')

@app.route('/article/<int:article_id>/delete')
@login_required
@permission_required('article.delete')  # 需要删除文章的权限
def delete_article(article_id):
    """删除文章(需要特定权限)"""
    article = Article.query.get_or_404(article_id)
    db.session.delete(article)
    db.session.commit()
    flash('文章已删除', 'success')
    return redirect(url_for('article.index'))

5. 进阶装饰器模式

5.1 装饰器组合器

from functools import reduce

def combine_decorators(*decorators):
    """组合多个装饰器"""
    def decorator(func):
        return reduce(lambda f, d: d(f), reversed(decorators), func)
    return decorator

# 定义权限组合
author_required = combine_decorators(login_required, role_required('author'))
admin_required = combine_decorators(login_required, role_required('admin'))

# 使用组合装饰器
@app.route('/author/dashboard')
@author_required
def author_dashboard():
    """作者仪表板"""
    return render_template('author/dashboard.html')

# 等效于
# @app.route('/author/dashboard')
# @login_required
# @role_required('author')
# def author_dashboard():
#     ...

5.2 动态权限装饰器

def dynamic_permission_required(permission_resolver):
    """
    动态权限装饰器
    permission_resolver: 函数,接收视图函数的参数,返回权限代码
    """
    def decorator(func):
        @wraps(func)
        def decorated_view(*args, **kwargs):
            if not current_user.is_authenticated:
                flash('请先登录', 'warning')
                return redirect(url_for('auth.login', next=request.url))

            # 动态解析需要的权限
            required_permission = permission_resolver(*args, **kwargs)

            if not current_user.has_permission(required_permission):
                flash('您没有执行此操作的权限', 'danger')
                return redirect(url_for('main.index'))

            return func(*args, **kwargs)
        return decorated_view
    return decorator

# 使用示例
def resolve_article_permission(article_id):
    """根据文章ID解析需要的权限"""
    article = Article.query.get(article_id)
    if article and article.is_published:
        return 'article.edit.published'
    else:
        return 'article.edit.draft'

@app.route('/article/<int:article_id>/edit')
@login_required
@dynamic_permission_required(resolve_article_permission)
def edit_article(article_id):
    """编辑文章(动态权限)"""
    article = Article.query.get_or_404(article_id)
    return render_template('article/edit.html', article=article)

5.3 基于资源的授权装饰器

def resource_permission_required(model_class, action, id_param='id'):
    """
    基于资源的权限装饰器
    model_class: 数据模型类
    action: 操作类型(view、edit、delete等)
    id_param: URL参数名,默认为'id'
    """
    def decorator(func):
        @wraps(func)
        def decorated_view(*args, **kwargs):
            if not current_user.is_authenticated:
                flash('请先登录', 'warning')
                return redirect(url_for('auth.login', next=request.url))

            # 获取资源ID
            resource_id = kwargs.get(id_param)

            if resource_id:
                # 获取资源
                resource = model_class.query.get(resource_id)

                if not resource:
                    abort(404)

                # 检查用户是否有权限操作该资源
                if not can_user_access_resource(current_user, resource, action):
                    flash('您没有权限操作此资源', 'danger')
                    return redirect(url_for('main.index'))

            return func(*args, **kwargs)
        return decorated_view
    return decorator

def can_user_access_resource(user, resource, action):
    """检查用户是否可以操作资源"""
    # 如果是资源所有者,允许所有操作
    if hasattr(resource, 'user_id') and resource.user_id == user.id:
        return True

    # 检查角色权限
    if action == 'view':
        return user.has_permission(f'{resource.__tablename__}.view')
    elif action == 'edit':
        return user.has_permission(f'{resource.__tablename__}.edit')
    elif action == 'delete':
        return user.has_permission(f'{resource.__tablename__}.delete')

    return False

# 使用示例
@app.route('/article/<int:id>')
@resource_permission_required(Article, 'view')
def view_article(id):
    """查看文章(需要查看权限)"""
    article = Article.query.get_or_404(id)
    return render_template('article/view.html', article=article)

@app.route('/article/<int:id>/delete', methods=['POST'])
@resource_permission_required(Article, 'delete')
def delete_article_resource(id):
    """删除文章(需要删除权限)"""
    article = Article.query.get_or_404(id)
    db.session.delete(article)
    db.session.commit()
    flash('文章已删除', 'success')
    return redirect(url_for('article.index'))

6. Flask-Login集成

from flask_login import LoginManager, UserMixin, login_user, logout_user, login_required

# 配置Flask-Login
login_manager = LoginManager()
login_manager.login_view = 'auth.login'
login_manager.login_message = '请先登录'
login_manager.login_message_category = 'warning'

@login_manager.user_loader
def load_user(user_id):
    """加载用户"""
    return User.query.get(int(user_id))

# 自定义的login_required装饰器,扩展Flask-Login的功能
def extended_login_required(func=None, fresh=False, redirect_to_login=True):
    """
    扩展的登录要求装饰器

    Args:
        func: 被装饰的函数
        fresh: 是否需要"新鲜"的登录(重新输入密码)
        redirect_to_login: 是否重定向到登录页面
    """
    def decorator(f):
        @wraps(f)
        def decorated_view(*args, **kwargs):
            if not current_user.is_authenticated:
                if redirect_to_login:
                    flash('请先登录', 'warning')
                    return redirect(url_for(login_manager.login_view, next=request.url))
                else:
                    abort(401)  # 未授权

            if fresh and not login_fresh():
                return login_manager.unauthorized()

            return f(*args, **kwargs)
        return decorated_view

    if func:
        return decorator(func)
    return decorator

# 使用示例
@app.route('/settings/security')
@extended_login_required(fresh=True)  # 需要重新输入密码
def security_settings():
    """安全设置页面(需要重新验证)"""
    return render_template('settings/security.html')

@app.route('/api/user')
@extended_login_required(redirect_to_login=False)  # API接口,返回401而不是重定向
def api_get_user():
    """API:获取用户信息"""
    return jsonify(current_user.to_dict())

7. 完整的授权装饰器库

# app/decorators.py
from functools import wraps
from flask import flash, redirect, url_for, request, abort, current_app
from flask_login import current_user, login_fresh
import logging

logger = logging.getLogger(__name__)

class AuthorizationDecorators:
    """授权装饰器集合"""

    @staticmethod
    def login_required(func=None, message=None, category='warning'):
        """登录要求装饰器"""
        def decorator(f):
            @wraps(f)
            def decorated_view(*args, **kwargs):
                if not current_user.is_authenticated:
                    if message:
                        flash(message, category)

                    # 记录访问日志
                    logger.warning(f'未授权访问: {request.url}')

                    return redirect(url_for(
                        current_app.login_manager.login_view,
                        next=request.url
                    ))

                return f(*args, **kwargs)
            return decorated_view

        if func:
            return decorator(func)
        return decorator

    @staticmethod
    def role_required(*roles):
        """角色要求装饰器(支持多个角色)"""
        def decorator(func):
            @wraps(func)
            def decorated_view(*args, **kwargs):
                if not current_user.is_authenticated:
                    flash('请先登录', 'warning')
                    return redirect(url_for('auth.login', next=request.url))

                # 检查用户是否有任一所需角色
                user_roles = {role.name for role in current_user.roles}
                if not any(role in user_roles for role in roles):
                    role_names = ', '.join(roles)
                    flash(f'需要以下角色之一: {role_names}', 'danger')
                    return redirect(url_for('main.index'))

                return func(*args, **kwargs)
            return decorated_view
        return decorator

    @staticmethod
    def permission_required(*permissions):
        """权限要求装饰器(支持多个权限)"""
        def decorator(func):
            @wraps(func)
            def decorated_view(*args, **kwargs):
                if not current_user.is_authenticated:
                    flash('请先登录', 'warning')
                    return redirect(url_for('auth.login', next=request.url))

                # 检查用户是否有任一所需权限
                has_any = False
                for perm in permissions:
                    if current_user.has_permission(perm):
                        has_any = True
                        break

                if not has_any:
                    perm_names = ', '.join(permissions)
                    flash(f'需要以下权限之一: {perm_names}', 'danger')

                    # 记录权限不足的访问
                    logger.warning(
                        f'权限不足: user={current_user.username}, '
                        f'path={request.path}, required={permissions}'
                    )

                    return redirect(url_for('main.index'))

                return func(*args, **kwargs)
            return decorated_view
        return decorator

    @staticmethod
    def confirm_required(func=None, message='请确认您的操作'):
        """操作确认装饰器"""
        def decorator(f):
            @wraps(f)
            def decorated_view(*args, **kwargs):
                confirm = request.args.get('confirm') == 'true'

                if not confirm:
                    # 返回确认页面
                    return render_template('confirm.html',
                                         message=message,
                                         next_url=request.url)

                return f(*args, **kwargs)
            return decorated_view

        if func:
            return decorator(func)
        return decorator

    @staticmethod
    def rate_limit(max_requests=10, window=60):
        """限流装饰器(防止暴力请求)"""
        from datetime import datetime, timedelta
        from collections import defaultdict

        request_log = defaultdict(list)

        def decorator(func):
            @wraps(func)
            def decorated_view(*args, **kwargs):
                # 获取客户端标识
                client_id = request.remote_addr

                # 清理过期记录
                now = datetime.now()
                request_log[client_id] = [
                    time for time in request_log[client_id]
                    if now - time < timedelta(seconds=window)
                ]

                # 检查是否超过限制
                if len(request_log[client_id]) >= max_requests:
                    flash('请求过于频繁,请稍后再试', 'warning')
                    return redirect(url_for('main.index'))

                # 记录本次请求
                request_log[client_id].append(now)

                return func(*args, **kwargs)
            return decorated_view
        return decorator

    @staticmethod
    def audit_log(action=None):
        """审计日志装饰器"""
        def decorator(func):
            @wraps(func)
            def decorated_view(*args, **kwargs):
                # 执行函数
                result = func(*args, **kwargs)

                # 记录审计日志
                if current_user.is_authenticated:
                    log_entry = AuditLog(
                        user_id=current_user.id,
                        action=action or func.__name__,
                        ip_address=request.remote_addr,
                        user_agent=request.user_agent.string,
                        path=request.path,
                        method=request.method,
                        status_code=200  # 假设成功
                    )
                    db.session.add(log_entry)
                    db.session.commit()

                return result
            return decorated_view
        return decorator

# 创建便捷的别名
login_required = AuthorizationDecorators.login_required
role_required = AuthorizationDecorators.role_required
permission_required = AuthorizationDecorators.permission_required
confirm_required = AuthorizationDecorators.confirm_required
rate_limit = AuthorizationDecorators.rate_limit
audit_log = AuthorizationDecorators.audit_log

# 使用示例
@app.route('/admin/delete_user/<int:user_id>')
@login_required(message='需要管理员权限')
@role_required('admin')
@confirm_required(message='确定要删除此用户吗?')
@audit_log(action='delete_user')
def delete_user(user_id):
    """删除用户(需要多重验证)"""
    user = User.query.get_or_404(user_id)
    db.session.delete(user)
    db.session.commit()
    flash('用户已删除', 'success')
    return redirect(url_for('admin.users'))

8. 最佳实践和常见模式

  • 使用@wraps装饰器:保留函数的元信息
  • 明确错误信息:告诉用户为什么被拒绝访问
  • 记录访问日志:监控未授权访问尝试
  • 防御CSRF攻击:确保所有修改操作都有CSRF保护
  • 最小权限原则:只授予必要的权限
  • 定期审查权限:定期检查和清理不再需要的权限
  • 测试授权逻辑:编写测试确保授权系统正确工作

9. 测试授权系统

# tests/test_authorization.py
import unittest
from flask import url_for
from app import create_app, db
from app.models import User, Role, Permission
from config import TestingConfig

class AuthorizationTestCase(unittest.TestCase):
    def setUp(self):
        """测试前准备"""
        self.app = create_app(TestingConfig)
        self.app_context = self.app.app_context()
        self.app_context.push()
        db.create_all()
        self.client = self.app.test_client()

        # 创建测试数据
        self.create_test_data()

    def tearDown(self):
        """测试后清理"""
        db.session.remove()
        db.drop_all()
        self.app_context.pop()

    def create_test_data(self):
        """创建测试数据"""
        # 创建角色
        admin_role = Role(name='admin', description='管理员')
        user_role = Role(name='user', description='普通用户')

        # 创建权限
        view_permission = Permission(name='查看文章', code='article.view')
        edit_permission = Permission(name='编辑文章', code='article.edit')
        delete_permission = Permission(name='删除文章', code='article.delete')

        # 关联角色和权限
        admin_role.permissions.extend([view_permission, edit_permission, delete_permission])
        user_role.permissions.append(view_permission)

        # 创建用户
        admin_user = User(username='admin', email='admin@test.com', password='admin123')
        normal_user = User(username='user', email='user@test.com', password='user123')

        # 分配角色
        admin_user.roles.append(admin_role)
        normal_user.roles.append(user_role)

        db.session.add_all([admin_role, user_role, view_permission,
                           edit_permission, delete_permission,
                           admin_user, normal_user])
        db.session.commit()

        self.admin_user = admin_user
        self.normal_user = normal_user

    def test_role_based_access(self):
        """测试基于角色的访问控制"""
        # 普通用户登录
        self.client.post(url_for('auth.login'), data={
            'username': 'user',
            'password': 'user123'
        })

        # 普通用户尝试访问管理员页面
        response = self.client.get(url_for('admin.dashboard'), follow_redirects=True)
        self.assertIn(b'没有权限访问此页面', response.data)

    def test_permission_based_access(self):
        """测试基于权限的访问控制"""
        # 普通用户登录
        self.client.post(url_for('auth.login'), data={
            'username': 'user',
            'password': 'user123'
        })

        # 普通用户尝试删除文章(没有删除权限)
        response = self.client.post(url_for('article.delete', article_id=1), follow_redirects=True)
        self.assertIn(b'没有执行此操作的权限', response.data)

    def test_decorator_combination(self):
        """测试装饰器组合"""
        # 未登录用户访问需要登录的页面
        response = self.client.get(url_for('dashboard'), follow_redirects=True)
        self.assertIn(b'请先登录', response.data)

    def test_api_authentication(self):
        """测试API认证"""
        # 未认证的API请求
        response = self.client.get(url_for('api.user'))
        self.assertEqual(response.status_code, 401)  # 未授权

if __name__ == '__main__':
    unittest.main()

10. 性能考虑

# 优化权限检查
class User(UserMixin, db.Model):
    # ... 其他字段

    # 缓存用户的权限,避免每次检查都查询数据库
    _permissions_cache = None

    @property
    def cached_permissions(self):
        """缓存的权限列表"""
        if self._permissions_cache is None:
            self._permissions_cache = set()
            for role in self.roles:
                for permission in role.permissions:
                    self._permissions_cache.add(permission.code)
        return self._permissions_cache

    def has_permission_cached(self, permission_code):
        """使用缓存的权限检查"""
        return permission_code in self.cached_permissions

    def clear_permissions_cache(self):
        """清空权限缓存"""
        self._permissions_cache = None

    def refresh_permissions_cache(self):
        """刷新权限缓存"""
        self.clear_permissions_cache()
        return self.cached_permissions

# 使用缓存进行权限检查的装饰器
def permission_required_cached(permission_code):
    """使用缓存检查权限的装饰器"""
    def decorator(func):
        @wraps(func)
        def decorated_view(*args, **kwargs):
            if not current_user.is_authenticated:
                flash('请先登录', 'warning')
                return redirect(url_for('auth.login', next=request.url))

            if not current_user.has_permission_cached(permission_code):
                flash('您没有执行此操作的权限', 'danger')
                return redirect(url_for('main.index'))

            return func(*args, **kwargs)
        return decorated_view
    return decorator

# 批量的权限检查装饰器
def check_permissions_deferred(*permission_codes):
    """
    延迟的权限检查装饰器
    在视图函数中手动调用权限检查
    """
    def decorator(func):
        @wraps(func)
        def decorated_view(*args, **kwargs):
            # 将权限检查延迟到视图函数中
            def check_permissions():
                for perm in permission_codes:
                    if not current_user.has_permission(perm):
                        return False, perm
                return True, None

            # 将检查函数注入到视图函数的关键字参数中
            kwargs['_check_permissions'] = check_permissions

            return func(*args, **kwargs)
        return decorated_view
    return decorator

# 使用示例
@app.route('/complex/operation')
@login_required
@check_permissions_deferred('resource.view', 'resource.edit', 'resource.approve')
def complex_operation(_check_permissions=None):
    """复杂的操作,需要多个权限"""
    # 手动检查权限
    has_access, missing_perm = _check_permissions()

    if not has_access:
        flash(f'缺少权限: {missing_perm}', 'danger')
        return redirect(url_for('main.index'))

    # 执行操作
    return render_template('complex/operation.html')