Flask数据库基本操作

数据库基本操作包括数据的增删改查(CRUD),这是Web应用开发中最常用的操作。

1. 准备工作:数据模型定义

首先定义一些示例模型,用于演示数据库操作:

from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from datetime import datetime

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///test.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

db = SQLAlchemy(app)

class User(db.Model):
    """用户模型"""
    __tablename__ = 'users'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False)
    email = db.Column(db.String(120), unique=True, nullable=False)
    age = db.Column(db.Integer)
    is_active = db.Column(db.Boolean, default=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # 一对多关系:一个用户有多篇文章
    posts = db.relationship('Post', backref='author', lazy='dynamic', cascade='all, delete-orphan')

    def __repr__(self):
        return f'<User {self.username}>'

class Post(db.Model):
    """文章模型"""
    __tablename__ = 'posts'

    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(200), nullable=False)
    content = db.Column(db.Text, nullable=False)
    views = db.Column(db.Integer, default=0)
    is_published = db.Column(db.Boolean, default=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)

    # 外键:作者ID
    author_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)

    def __repr__(self):
        return f'<Post {self.title}>'

class Category(db.Model):
    """分类模型"""
    __tablename__ = 'categories'

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), unique=True, nullable=False)
    description = db.Column(db.Text)

    # 多对多关系:文章和分类
    posts = db.relationship('Post', secondary='post_categories', backref='categories')

    def __repr__(self):
        return f'<Category {self.name}>'

# 关联表:文章分类
post_categories = db.Table('post_categories',
    db.Column('post_id', db.Integer, db.ForeignKey('posts.id'), primary_key=True),
    db.Column('category_id', db.Integer, db.ForeignKey('categories.id'), primary_key=True),
    db.Column('created_at', db.Column(db.DateTime, default=datetime.utcnow))
)

2. 创建数据(Create)

2.1 添加单条记录

# 方法1:使用模型构造函数创建对象
def create_single_record():
    """创建单条记录"""
    # 创建用户对象
    user = User(
        username='john_doe',
        email='john@example.com',
        age=25,
        is_active=True
    )

    # 添加到会话
    db.session.add(user)

    # 提交到数据库
    try:
        db.session.commit()
        print(f"用户创建成功: {user.username}")
        return user
    except Exception as e:
        db.session.rollback()
        print(f"创建失败: {e}")
        return None

# 方法2:直接使用session.add()并提交
def create_user_simple():
    """简化版创建用户"""
    user = User(username='jane_doe', email='jane@example.com')
    db.session.add(user)
    db.session.commit()
    return user

# 方法3:使用create()方法(如果模型中有定义)
class User(db.Model):
    # ... 其他字段定义

    @classmethod
    def create(cls, **kwargs):
        """创建用户的便捷方法"""
        user = cls(**kwargs)
        db.session.add(user)
        db.session.commit()
        return user

# 使用便捷方法
user = User.create(username='bob', email='bob@example.com')

2.2 批量添加记录

def create_bulk_records():
    """批量创建记录"""
    # 准备用户数据列表
    users_data = [
        {'username': 'alice', 'email': 'alice@example.com', 'age': 28},
        {'username': 'bob', 'email': 'bob@example.com', 'age': 32},
        {'username': 'charlie', 'email': 'charlie@example.com', 'age': 24},
        {'username': 'david', 'email': 'david@example.com', 'age': 35},
        {'username': 'eve', 'email': 'eve@example.com', 'age': 29},
    ]

    # 方法1:使用循环逐个添加
    users = []
    for data in users_data:
        user = User(**data)
        db.session.add(user)
        users.append(user)

    # 方法2:使用add_all()一次性添加
    # users = [User(**data) for data in users_data]
    # db.session.add_all(users)

    try:
        db.session.commit()
        print(f"批量创建成功,共创建了 {len(users)} 个用户")
        return users
    except Exception as e:
        db.session.rollback()
        print(f"批量创建失败: {e}")
        return []

def create_with_relationships():
    """创建有关联关系的记录"""
    # 创建用户
    user = User(username='author1', email='author1@example.com')
    db.session.add(user)
    db.session.commit()  # 需要先提交,获取用户ID

    # 创建文章并关联用户
    post1 = Post(
        title='我的第一篇文章',
        content='这是文章内容...',
        author_id=user.id
    )

    post2 = Post(
        title='Flask数据库教程',
        content='学习Flask数据库操作...',
        author_id=user.id
    )

    # 更简单的方式:直接使用关系
    post3 = Post(title='Python入门', content='Python基础教程...')
    user.posts.append(post3)  # 通过关系添加

    # 创建分类
    python_category = Category(name='Python', description='Python相关文章')
    flask_category = Category(name='Flask', description='Flask框架相关文章')

    # 关联文章和分类
    post1.categories.append(python_category)
    post2.categories.extend([python_category, flask_category])

    db.session.add_all([post1, post2, post3, python_category, flask_category])
    db.session.commit()

    print("创建完成,包含关联关系")
    return user

3. 查询数据(Read)

3.1 基本查询方法

def basic_queries():
    """基本查询操作"""

    # 1. 查询所有记录
    all_users = User.query.all()
    print(f"所有用户: {all_users}")

    # 2. 查询第一条记录
    first_user = User.query.first()
    print(f"第一个用户: {first_user}")

    # 3. 通过主键查询
    user_by_id = User.query.get(1)  # 查询ID为1的用户
    print(f"ID为1的用户: {user_by_id}")

    # 4. 查询记录数量
    user_count = User.query.count()
    print(f"用户总数: {user_count}")

    # 5. 检查是否存在
    exists = User.query.filter_by(username='john_doe').first() is not None
    print(f"用户john_doe是否存在: {exists}")

    return all_users

def filter_queries():
    """条件过滤查询"""

    # 1. 等值过滤
    active_users = User.query.filter_by(is_active=True).all()
    print(f"活跃用户: {active_users}")

    # 2. 比较过滤
    adult_users = User.query.filter(User.age >= 18).all()
    print(f"成年用户: {adult_users}")

    # 3. 多条件过滤
    active_adults = User.query.filter(
        User.is_active == True,
        User.age >= 18
    ).all()
    print(f"活跃的成年用户: {active_adults}")

    # 4. LIKE模糊查询
    users_like_john = User.query.filter(User.username.like('%john%')).all()
    print(f"用户名包含john的用户: {users_like_john}")

    # 5. IN查询
    specific_users = User.query.filter(User.username.in_(['john_doe', 'jane_doe'])).all()
    print(f"特定用户名用户: {specific_users}")

    # 6. 范围查询
    users_in_range = User.query.filter(User.age.between(20, 30)).all()
    print(f"年龄在20-30之间的用户: {users_in_range}")

    # 7. NULL值查询
    users_without_age = User.query.filter(User.age == None).all()
    print(f"年龄为空的用户: {users_without_age}")

    return active_users

3.2 排序和限制

def sorting_and_limiting():
    """排序和限制查询"""

    # 1. 升序排序
    users_asc = User.query.order_by(User.username).all()
    print(f"按用户名升序: {users_asc}")

    # 2. 降序排序
    users_desc = User.query.order_by(User.username.desc()).all()
    print(f"按用户名降序: {users_desc}")

    # 3. 多列排序
    users_multi_sort = User.query.order_by(User.is_active.desc(), User.username).all()
    print(f"先按活跃状态降序,再按用户名升序: {users_multi_sort}")

    # 4. 限制结果数量
    recent_users = User.query.order_by(User.created_at.desc()).limit(5).all()
    print(f"最近创建的5个用户: {recent_users}")

    # 5. 偏移查询(分页)
    page_2_users = User.query.order_by(User.id).offset(10).limit(10).all()
    print(f"第2页用户(每页10条): {page_2_users}")

    # 6. 组合:最近活跃的成年用户
    recent_active_adults = User.query.filter(
        User.is_active == True,
        User.age >= 18
    ).order_by(
        User.last_login.desc()
    ).limit(10).all()

    return recent_users

3.3 关联查询

def relationship_queries():
    """关联关系查询"""

    # 1. 查询用户的所有文章
    user = User.query.get(1)
    if user:
        user_posts = user.posts.all()  # 使用关系属性
        print(f"用户 {user.username} 的文章: {user_posts}")

        # 也可以使用查询
        posts_by_user = Post.query.filter_by(author_id=user.id).all()

    # 2. 查询文章的作者
    post = Post.query.first()
    if post:
        author = post.author  # 通过backref访问
        print(f"文章 {post.title} 的作者: {author.username}")

    # 3. 带条件的关联查询
    # 查询所有发表过文章的用户
    users_with_posts = User.query.join(Post).filter(Post.is_published == True).all()
    print(f"发表过文章的用户: {users_with_posts}")

    # 4. 查询文章数量大于N的用户
    from sqlalchemy import func
    active_authors = db.session.query(
        User, func.count(Post.id).label('post_count')
    ).join(
        Post, User.id == Post.author_id
    ).filter(
        Post.is_published == True
    ).group_by(
        User.id
    ).having(
        func.count(Post.id) > 0
    ).all()

    print(f"活跃作者及其文章数量: {active_authors}")

    # 5. 多对多关联查询
    python_posts = Post.query.join(
        post_categories
    ).join(
        Category
    ).filter(
        Category.name == 'Python'
    ).all()

    print(f"Python分类下的文章: {python_posts}")

    return users_with_posts

3.4 聚合查询

def aggregate_queries():
    """聚合查询"""

    from sqlalchemy import func

    # 1. 计数
    total_users = db.session.query(func.count(User.id)).scalar()
    print(f"用户总数: {total_users}")

    # 2. 求和
    total_age = db.session.query(func.sum(User.age)).scalar()
    print(f"用户年龄总和: {total_age}")

    # 3. 平均值
    avg_age = db.session.query(func.avg(User.age)).scalar()
    print(f"平均年龄: {avg_age}")

    # 4. 最大值和最小值
    max_age = db.session.query(func.max(User.age)).scalar()
    min_age = db.session.query(func.min(User.age)).scalar()
    print(f"最大年龄: {max_age}, 最小年龄: {min_age}")

    # 5. 分组统计
    age_groups = db.session.query(
        User.age,
        func.count(User.id).label('count')
    ).filter(
        User.age != None
    ).group_by(
        User.age
    ).order_by(
        User.age
    ).all()

    print("按年龄分组统计:")
    for age, count in age_groups:
        print(f"  年龄 {age}: {count} 人")

    # 6. 复杂聚合:每个用户的文章数量和总浏览量
    user_stats = db.session.query(
        User.username,
        func.count(Post.id).label('post_count'),
        func.sum(Post.views).label('total_views')
    ).outerjoin(
        Post, User.id == Post.author_id
    ).group_by(
        User.id
    ).order_by(
        func.count(Post.id).desc()
    ).all()

    print("用户文章统计:")
    for username, post_count, total_views in user_stats:
        print(f"  {username}: {post_count} 篇文章, {total_views or 0} 次浏览")

    return age_groups

4. 更新数据(Update)

def update_operations():
    """更新操作"""

    # 1. 更新单个字段
    user = User.query.get(1)
    if user:
        user.username = 'new_username'
        db.session.commit()
        print(f"用户名已更新: {user.username}")

    # 2. 更新多个字段
    user = User.query.filter_by(email='john@example.com').first()
    if user:
        user.age = 30
        user.is_active = False
        user.updated_at = datetime.utcnow()
        db.session.commit()
        print(f"用户信息已更新")

    # 3. 批量更新
    # 将所有不活跃用户的年龄设为0
    updated_count = User.query.filter_by(is_active=False).update(
        {'age': 0},
        synchronize_session=False
    )
    db.session.commit()
    print(f"批量更新了 {updated_count} 个用户")

    # 4. 递增/递减操作
    post = Post.query.first()
    if post:
        post.views = Post.views + 1  # 或使用 increment_views 方法
        db.session.commit()
        print(f"文章浏览量已增加")

    # 5. 使用自定义方法更新
    class User(db.Model):
        # ... 其他字段

        def update_profile(self, **kwargs):
            """更新用户资料"""
            for key, value in kwargs.items():
                if hasattr(self, key):
                    setattr(self, key, value)
            self.updated_at = datetime.utcnow()
            db.session.commit()

    # 使用自定义方法
    user = User.query.get(1)
    user.update_profile(username='updated_name', age=35)

    # 6. 更新关联数据
    user = User.query.get(1)
    if user and user.posts:
        # 更新用户的所有文章
        for post in user.posts:
            post.views += 10
        db.session.commit()
        print(f"更新了用户的所有文章浏览量")

    return user

4.1 乐观锁实现

class Product(db.Model):
    """使用版本控制的商品模型"""
    __tablename__ = 'products'

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(100), nullable=False)
    stock = db.Column(db.Integer, default=0)
    price = db.Column(db.Numeric(10, 2), default=0.00)
    version = db.Column(db.Integer, default=0)  # 版本号

    def decrease_stock(self, quantity):
        """减少库存(乐观锁)"""
        from sqlalchemy import update

        if self.stock < quantity:
            raise ValueError("库存不足")

        # 使用版本控制防止并发问题
        stmt = update(Product).where(
            Product.id == self.id,
            Product.version == self.version
        ).values(
            stock=Product.stock - quantity,
            version=Product.version + 1
        )

        result = db.session.execute(stmt)
        db.session.commit()

        if result.rowcount == 0:
            # 版本不匹配,说明数据已被修改
            db.session.rollback()
            raise ValueError("更新冲突,请重试")

        # 刷新本地对象
        db.session.refresh(self)
        print(f"库存已减少 {quantity},当前库存: {self.stock}")

    def increase_price(self, percentage):
        """增加价格(乐观锁)"""
        from sqlalchemy import update

        new_price = self.price * (1 + percentage / 100)

        stmt = update(Product).where(
            Product.id == self.id,
            Product.version == self.version
        ).values(
            price=new_price,
            version=Product.version + 1
        )

        result = db.session.execute(stmt)
        db.session.commit()

        if result.rowcount == 0:
            db.session.rollback()
            raise ValueError("更新冲突,请重试")

        db.session.refresh(self)

5. 删除数据(Delete)

def delete_operations():
    """删除操作"""

    # 1. 删除单条记录
    user = User.query.get(10)  # 假设删除ID为10的用户
    if user:
        db.session.delete(user)
        db.session.commit()
        print(f"用户 {user.username} 已删除")

    # 2. 条件删除
    deleted_count = User.query.filter_by(is_active=False).delete()
    db.session.commit()
    print(f"删除了 {deleted_count} 个不活跃用户")

    # 3. 删除所有记录(谨慎使用!)
    # all_deleted = User.query.delete()
    # db.session.commit()

    # 4. 级联删除(通过关系)
    user = User.query.get(1)
    if user:
        # 删除用户及其所有文章(因为设置了cascade)
        db.session.delete(user)
        db.session.commit()
        print(f"用户及其关联数据已删除")

    # 5. 软删除(标记删除而非物理删除)
    class SoftDeleteMixin:
        """软删除混入类"""
        is_deleted = db.Column(db.Boolean, default=False)
        deleted_at = db.Column(db.DateTime)

        def soft_delete(self):
            """软删除"""
            self.is_deleted = True
            self.deleted_at = datetime.utcnow()
            db.session.commit()

        @classmethod
        def query_active(cls):
            """查询未删除的记录"""
            return cls.query.filter_by(is_deleted=False)

    # 6. 删除关联表中的记录
    # 删除文章的所有分类关联
    post = Post.query.get(1)
    if post:
        post.categories = []  # 清空分类关联
        db.session.commit()
        print(f"文章分类关联已清空")

    # 或者直接从关联表删除
    from sqlalchemy import delete as sql_delete
    stmt = sql_delete(post_categories).where(
        post_categories.c.post_id == 1
    )
    db.session.execute(stmt)
    db.session.commit()

    return deleted_count

6. 事务管理

def transaction_examples():
    """事务管理示例"""

    # 1. 基本事务
    try:
        user1 = User(username='user1', email='user1@example.com')
        user2 = User(username='user2', email='user2@example.com')

        db.session.add(user1)
        db.session.add(user2)

        # 模拟一个可能失败的操作
        if user1.username == user2.username:
            raise ValueError("用户名重复")

        db.session.commit()
        print("事务提交成功")

    except Exception as e:
        db.session.rollback()
        print(f"事务回滚: {e}")

    # 2. 嵌套事务
    try:
        # 开始事务
        user = User(username='nested_user', email='nested@example.com')
        db.session.add(user)

        # 嵌套操作
        try:
            post = Post(title='Nested Post', content='...', author_id=user.id)
            db.session.add(post)
            db.session.commit()
            print("嵌套事务提交成功")
        except Exception as e:
            db.session.rollback()
            print(f"嵌套事务回滚: {e}")
            raise  # 重新抛出异常

    except Exception as e:
        db.session.rollback()
        print(f"外层事务回滚: {e}")

    # 3. 使用上下文管理器
    from contextlib import contextmanager

    @contextmanager
    def transaction():
        """事务上下文管理器"""
        try:
            yield
            db.session.commit()
        except Exception as e:
            db.session.rollback()
            raise e

    # 使用事务上下文管理器
    with transaction():
        user = User(username='ctx_user', email='ctx@example.com')
        post = Post(title='Ctx Post', content='...', author_id=user.id)
        db.session.add_all([user, post])
        print("上下文管理器事务完成")

    # 4. 保存点(复杂事务)
    try:
        user1 = User(username='savepoint1', email='sp1@example.com')
        db.session.add(user1)

        # 创建保存点
        savepoint = db.session.begin_nested()

        try:
            user2 = User(username='savepoint2', email='sp2@example.com')
            db.session.add(user2)

            # 模拟失败
            if True:  # 测试用,总是触发回滚到保存点
                raise ValueError("保存点内的操作失败")

            savepoint.commit()
        except Exception as e:
            savepoint.rollback()
            print(f"保存点回滚: {e}")
            # 外层事务继续

        db.session.commit()
        print("主事务提交成功")

    except Exception as e:
        db.session.rollback()
        print(f"主事务回滚: {e}")

7. 分页查询

def pagination_examples():
    """分页查询示例"""

    # 1. 使用SQLAlchemy的分页器
    def get_users_page(page=1, per_page=10):
        """获取用户分页"""
        pagination = User.query.filter_by(is_active=True).order_by(
            User.created_at.desc()
        ).paginate(
            page=page,
            per_page=per_page,
            error_out=False  # 如果page超出范围不抛出异常
        )

        return {
            'users': pagination.items,
            'page': pagination.page,
            'per_page': pagination.per_page,
            'total': pagination.total,
            'pages': pagination.pages,
            'has_prev': pagination.has_prev,
            'has_next': pagination.has_next,
            'prev_num': pagination.prev_num,
            'next_num': pagination.next_num
        }

    # 2. 手动实现分页
    def manual_pagination(page=1, per_page=10):
        """手动分页"""
        offset = (page - 1) * per_page

        users = User.query.order_by(User.id).offset(offset).limit(per_page).all()
        total = User.query.count()

        return {
            'users': users,
            'page': page,
            'per_page': per_page,
            'total': total,
            'pages': (total + per_page - 1) // per_page
        }

    # 3. 在视图函数中使用分页
    from flask import request, render_template

    @app.route('/users')
    def list_users():
        page = request.args.get('page', 1, type=int)
        per_page = request.args.get('per_page', 10, type=int)

        pagination = User.query.paginate(page=page, per_page=per_page)

        return render_template('users/list.html', pagination=pagination)

    # 4. 复杂查询的分页
    def get_posts_by_category(category_id, page=1, per_page=10):
        """按分类获取文章分页"""
        query = Post.query.join(
            post_categories
        ).filter(
            post_categories.c.category_id == category_id,
            Post.is_published == True
        ).order_by(
            Post.created_at.desc()
        )

        pagination = query.paginate(page=page, per_page=per_page)

        return pagination

    # 5. 分页模板示例
    '''
    <!-- templates/users/list.html -->
    @% for user in pagination.items %>
        <div>{{ user.username }}</div>
    @% endfor %>

    <!-- 分页导航 -->
    <nav>
        <ul class="pagination">
            @% if pagination.has_prev %>
                <li class="page-item">
                    <a class="page-link" href="?page={{ pagination.prev_num }}">上一页</a>
                </li>
            @% endif %>

            @% for page_num in pagination.iter_pages() %>
                @% if page_num %>
                    <li class="page-item {{ 'active' if page_num == pagination.page else '' }}">
                        <a class="page-link" href="?page={{ page_num }}">{{ page_num }}</a>
                    </li>
                @% else %>
                    <li class="page-item disabled">
                        <span class="page-link">...</span>
                    </li>
                @% endif %>
            @% endfor %>

            @% if pagination.has_next %>
                <li class="page-item">
                    <a class="page-link" href="?page={{ pagination.next_num }}">下一页</a>
                </li>
            @% endif %>
        </ul>
    </nav>
    '''

    return get_users_page(1, 10)

8. 高级查询技巧

def advanced_query_techniques():
    """高级查询技巧"""

    from sqlalchemy import and_, or_, not_
    from sqlalchemy.orm import aliased

    # 1. 使用and_, or_, not_组合条件
    complex_query = User.query.filter(
        or_(
            and_(User.age >= 18, User.age <= 30),
            User.is_active == True
        ),
        not_(User.username.like('%admin%'))
    ).all()

    # 2. 子查询
    from sqlalchemy import select

    # 创建子查询:文章数量大于平均值的用户
    subquery = select([User.id]).join(Post).group_by(User.id).having(
        func.count(Post.id) > select([func.avg(func.count(Post.id))]).select_from(
            User
        ).join(Post).group_by(User.id).as_scalar()
    ).alias()

    active_authors = User.query.filter(User.id.in_(subquery)).all()

    # 3. 使用别名
    UserAlias = aliased(User)
    query = db.session.query(User, UserAlias).filter(
        User.id != UserAlias.id,
        User.email == UserAlias.email
    ).all()

    # 4. 窗口函数(需要数据库支持)
    from sqlalchemy import over

    # 给用户按年龄排名
    ranked_users = db.session.query(
        User,
        func.row_number().over(
            order_by=User.age.desc(),
            partition_by=User.is_active
        ).label('rank')
    ).all()

    # 5. 递归查询(用于树状结构)
    class Category(db.Model):
        # ... 其他字段
        parent_id = db.Column(db.Integer, db.ForeignKey('categories.id'))
        parent = db.relationship('Category', remote_side=[id], backref='children')

    # 查询某个分类的所有子分类
    def get_all_subcategories(category_id):
        from sqlalchemy.orm import aliased

        CategoryAlias = aliased(Category)

        # 使用递归CTE(需要数据库支持)
        # 这里简化处理:多次查询
        all_categories = []
        current_level = [category_id]

        while current_level:
            next_level = Category.query.filter(
                Category.parent_id.in_(current_level)
            ).all()

            all_categories.extend(next_level)
            current_level = [cat.id for cat in next_level]

        return all_categories

    # 6. 全文搜索(需要数据库支持全文索引)
    def fulltext_search(search_term):
        """全文搜索示例"""
        # 使用LIKE模拟
        results = Post.query.filter(
            or_(
                Post.title.ilike(f'%{search_term}%'),
                Post.content.ilike(f'%{search_term}%')
            )
        ).all()

        # 如果使用PostgreSQL,可以使用更高级的全文搜索
        # from sqlalchemy import text
        # results = Post.query.filter(
        #     text("to_tsvector('english', title || ' ' || content) @@ plainto_tsquery(:query)")
        # ).params(query=search_term).all()

        return results

    return complex_query

9. 性能优化

def performance_optimization():
    """性能优化技巧"""

    from sqlalchemy.orm import joinedload, selectinload, subqueryload

    # 1. 解决N+1查询问题
    # ❌ 不好的做法:每次循环都会查询数据库
    users = User.query.all()
    for user in users:
        posts = user.posts.all()  # N+1查询!

    # ✅ 好的做法:使用joinedload急切加载
    users_with_posts = User.query.options(joinedload(User.posts)).all()
    for user in users_with_posts:
        posts = user.posts  # 已加载,不会再次查询

    # 2. 使用selectinload(适合一对多关系)
    users = User.query.options(selectinload(User.posts)).all()

    # 3. 只选择需要的字段
    # ❌ 选择所有字段
    users = User.query.all()

    # ✅ 只选择需要的字段
    users_light = db.session.query(User.id, User.username, User.email).all()

    # 4. 使用索引优化查询
    # 确保常用查询字段有索引
    # 在模型定义中添加:index=True

    # 5. 批量操作代替循环
    # ❌ 循环更新
    users = User.query.all()
    for user in users:
        user.login_count += 1
    db.session.commit()

    # ✅ 批量更新
    User.query.update({'login_count': User.login_count + 1})
    db.session.commit()

    # 6. 使用数据库的延迟约束
    # 在大量插入时,可以暂时禁用外键检查
    if app.config['SQLALCHEMY_DATABASE_URI'].startswith('sqlite'):
        db.session.execute('PRAGMA foreign_keys = OFF')
        # 执行批量插入
        db.session.execute('PRAGMA foreign_keys = ON')

    # 7. 分页而不是一次性加载所有数据
    # 总是使用分页处理大量数据

    # 8. 使用缓存
    from flask_caching import Cache

    cache = Cache(app, config={'CACHE_TYPE': 'simple'})

    @cache.memoize(timeout=300)  # 缓存5分钟
    def get_active_users_count():
        return User.query.filter_by(is_active=True).count()

    return users_with_posts

10. 实用工具函数

class DatabaseUtils:
    """数据库工具类"""

    @staticmethod
    def get_or_create(model, **kwargs):
        """获取或创建记录"""
        instance = model.query.filter_by(**kwargs).first()
        if instance:
            return instance, False  # False表示已存在

        instance = model(**kwargs)
        db.session.add(instance)
        db.session.commit()
        return instance, True  # True表示新创建的

    @staticmethod
    def update_or_create(model, defaults=None, **kwargs):
        """更新或创建记录"""
        instance = model.query.filter_by(**kwargs).first()

        if instance:
            # 更新现有记录
            for key, value in (defaults or {}).items():
                setattr(instance, key, value)
            created = False
        else:
            # 创建新记录
            params = kwargs.copy()
            params.update(defaults or {})
            instance = model(**params)
            created = True

        db.session.add(instance)
        db.session.commit()
        return instance, created

    @staticmethod
    def bulk_upsert(model, data_list, conflict_fields):
        """批量插入或更新"""
        # 简化实现,实际应用中可能需要更复杂的逻辑
        for data in data_list:
            filter_args = {field: data[field] for field in conflict_fields}
            instance = model.query.filter_by(**filter_args).first()

            if instance:
                for key, value in data.items():
                    setattr(instance, key, value)
            else:
                instance = model(**data)
                db.session.add(instance)

        db.session.commit()

    @staticmethod
    def export_to_csv(model, filepath):
        """导出数据到CSV"""
        import csv

        records = model.query.all()
        if not records:
            return False

        # 获取字段名
        fields = [column.name for column in model.__table__.columns]

        with open(filepath, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(fields)

            for record in records:
                row = [getattr(record, field) for field in fields]
                writer.writerow(row)

        return True

    @staticmethod
    def import_from_csv(model, filepath):
        """从CSV导入数据"""
        import csv

        with open(filepath, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                instance = model(**row)
                db.session.add(instance)

        db.session.commit()
        return True

# 使用示例
user, created = DatabaseUtils.get_or_create(
    User,
    username='test_user',
    defaults={'email': 'test@example.com', 'age': 25}
)