Flask数据库连接

Flask可以使用多种方式连接数据库,最常用的是通过SQLAlchemy扩展,它支持SQLite、MySQL、PostgreSQL等多种数据库。

1. 安装必要的扩展

首先需要安装Flask-SQLAlchemy扩展:

# 安装Flask-SQLAlchemy
pip install Flask-SQLAlchemy

# 如果需要数据库迁移功能,安装Flask-Migrate
pip install Flask-Migrate

# 根据数据库类型安装相应的驱动
# SQLite(Python自带,无需额外安装)

# MySQL
pip install pymysql
# 或
pip install mysqlclient

# PostgreSQL
pip install psycopg2
# 或
pip install psycopg2-binary

# SQL Server
pip install pyodbc

2. 数据库连接配置

2.1 基本配置

from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)

# 数据库配置
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///site.db'  # SQLite数据库
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False  # 关闭修改跟踪(减少内存使用)
app.config['SQLALCHEMY_ECHO'] = True  # 设置为True可查看生成的SQL语句(调试用)

# 创建SQLAlchemy实例
db = SQLAlchemy(app)

# 可选:配置其他数据库连接参数
app.config.update({
    'SQLALCHEMY_POOL_SIZE': 10,  # 连接池大小
    'SQLALCHEMY_POOL_TIMEOUT': 30,  # 连接超时时间
    'SQLALCHEMY_POOL_RECYCLE': 3600,  # 连接回收时间(秒)
    'SQLALCHEMY_MAX_OVERFLOW': 20,  # 连接池最大溢出数
})

2.2 不同数据库的连接字符串

数据库 连接字符串格式 示例
SQLite sqlite:///database.db sqlite:///site.db
MySQL mysql://username:password@host/database mysql://root:password@localhost/mydb
PostgreSQL postgresql://username:password@host/database postgresql://postgres:password@localhost/mydb
SQL Server mssql+pyodbc://username:password@dsn mssql+pyodbc://sa:password@mydb
Oracle oracle://username:password@host:port/database oracle://scott:tiger@localhost:1521/orcl

2.3 完整的配置示例

import os
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.engine import Engine
from sqlalchemy import event

app = Flask(__name__)

# 基础配置类
class Config:
    SECRET_KEY = os.environ.get('SECRET_KEY') or 'dev-secret-key'
    SQLALCHEMY_TRACK_MODIFICATIONS = False
    SQLALCHEMY_RECORD_QUERIES = False  # 是否记录查询(性能分析用)

    # SQLite配置
    SQLALCHEMY_DATABASE_URI = 'sqlite:///app.db'

    # MySQL配置(根据环境变量切换)
    # SQLALCHEMY_DATABASE_URI = 'mysql+pymysql://root:password@localhost/mydb'
    # SQLALCHEMY_ENGINE_OPTIONS = {
    #     'pool_size': 10,
    #     'pool_recycle': 3600,
    #     'pool_pre_ping': True,  # 连接前检查连接是否有效
    # }

# 根据环境选择配置
if os.environ.get('FLASK_ENV') == 'production':
    app.config.from_object(ProductionConfig)
elif os.environ.get('FLASK_ENV') == 'testing':
    app.config.from_object(TestingConfig)
else:
    app.config.from_object(DevelopmentConfig)

# 启用SQLite外键约束(SQLite默认关闭)
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    if app.config['SQLALCHEMY_DATABASE_URI'].startswith('sqlite'):
        cursor = dbapi_connection.cursor()
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.close()

# 初始化数据库
db = SQLAlchemy(app)

3. 定义数据模型

from datetime import datetime

# 用户模型
class User(db.Model):
    """用户表"""
    __tablename__ = 'users'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False)
    email = db.Column(db.String(120), unique=True, nullable=False)
    password_hash = db.Column(db.String(128), nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    is_active = db.Column(db.Boolean, default=True)

    # 关系:一对多(用户可以有多个文章)
    posts = db.relationship('Post', backref='author', lazy='dynamic', cascade='all, delete-orphan')

    # 关系:多对多(用户关注其他用户)
    following = db.relationship('User',
                              secondary='follows',
                              primaryjoin='Follow.follower_id == User.id',
                              secondaryjoin='Follow.followed_id == User.id',
                              backref=db.backref('followers', lazy='dynamic'),
                              lazy='dynamic')

    def __repr__(self):
        return f'<User {self.username}>'

# 文章模型
class Post(db.Model):
    """文章表"""
    __tablename__ = 'posts'

    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(200), nullable=False)
    content = db.Column(db.Text, nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    is_published = db.Column(db.Boolean, default=False)

    # 外键:作者ID
    author_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)

    # 关系:一对多(文章可以有多个评论)
    comments = db.relationship('Comment', backref='post', lazy='dynamic', cascade='all, delete-orphan')

    # 关系:多对多(文章分类)
    categories = db.relationship('Category',
                                secondary='post_categories',
                                backref=db.backref('posts', lazy='dynamic'),
                                lazy='dynamic')

    def __repr__(self):
        return f'<Post {self.title}>'

# 评论模型
class Comment(db.Model):
    """评论表"""
    __tablename__ = 'comments'

    id = db.Column(db.Integer, primary_key=True)
    content = db.Column(db.Text, nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    is_approved = db.Column(db.Boolean, default=True)

    # 外键:文章ID和用户ID
    post_id = db.Column(db.Integer, db.ForeignKey('posts.id'), nullable=False)
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)

    # 关系
    user = db.relationship('User', backref='comments')

    def __repr__(self):
        return f'<Comment {self.id}>'

# 分类模型
class Category(db.Model):
    """分类表"""
    __tablename__ = 'categories'

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50), unique=True, nullable=False)
    slug = db.Column(db.String(50), unique=True, nullable=False)
    description = db.Column(db.Text)

    def __repr__(self):
        return f'<Category {self.name}>'

# 关联表:关注关系
class Follow(db.Model):
    """用户关注关系表"""
    __tablename__ = 'follows'

    follower_id = db.Column(db.Integer, db.ForeignKey('users.id'), primary_key=True)
    followed_id = db.Column(db.Integer, db.ForeignKey('users.id'), primary_key=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)

# 关联表:文章分类
post_categories = db.Table('post_categories',
    db.Column('post_id', db.Integer, db.ForeignKey('posts.id'), primary_key=True),
    db.Column('category_id', db.Integer, db.ForeignKey('categories.id'), primary_key=True),
    db.Column('created_at', db.DateTime, default=datetime.utcnow)
)

4. 创建和初始化数据库

4.1 使用Flask命令行

# 创建数据库表
from app import app, db

with app.app_context():
    # 删除所有表(谨慎使用!)
    # db.drop_all()

    # 创建所有表
    db.create_all()

    print("数据库表创建成功!")

# 或者使用Flask命令行
# 在终端执行:
# flask shell
# >>> from app import db
# >>> db.create_all()

4.2 使用Flask-Migrate进行数据库迁移

# app.py
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///app.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

db = SQLAlchemy(app)
migrate = Migrate(app, db)

# 模型定义...
# 初始化迁移仓库
flask db init

# 生成迁移脚本
flask db migrate -m "创建用户表"

# 应用迁移
flask db upgrade

# 回滚迁移
flask db downgrade

# 查看迁移历史
flask db history

# 显示当前迁移状态
flask db show

5. 数据库基本操作

5.1 CRUD操作示例

from datetime import datetime

def create_user():
    """创建用户"""
    # 创建新用户
    user = User(
        username='john_doe',
        email='john@example.com',
        password_hash='hashed_password_here',
        created_at=datetime.utcnow()
    )

    # 添加到会话
    db.session.add(user)

    # 提交到数据库
    try:
        db.session.commit()
        print(f"用户 {user.username} 创建成功")
    except Exception as e:
        db.session.rollback()
        print(f"创建用户失败: {e}")

    return user

def get_users():
    """查询用户"""
    # 查询所有用户
    users = User.query.all()

    # 分页查询
    page = 1
    per_page = 10
    users_paginated = User.query.paginate(page=page, per_page=per_page)

    # 条件查询
    active_users = User.query.filter_by(is_active=True).all()

    # 复杂查询
    recent_users = User.query.filter(
        User.created_at >= datetime(2023, 1, 1)
    ).order_by(User.created_at.desc()).limit(10).all()

    # 查询单个用户
    user = User.query.filter_by(username='john_doe').first()

    return users

def update_user(user_id):
    """更新用户"""
    user = User.query.get(user_id)
    if user:
        user.email = 'new_email@example.com'
        user.updated_at = datetime.utcnow()

        try:
            db.session.commit()
            print(f"用户 {user.username} 更新成功")
        except Exception as e:
            db.session.rollback()
            print(f"更新用户失败: {e}")

    return user

def delete_user(user_id):
    """删除用户"""
    user = User.query.get(user_id)
    if user:
        db.session.delete(user)

        try:
            db.session.commit()
            print(f"用户 {user.username} 删除成功")
        except Exception as e:
            db.session.rollback()
            print(f"删除用户失败: {e}")

def create_post():
    """创建文章"""
    user = User.query.filter_by(username='john_doe').first()

    if user:
        post = Post(
            title='我的第一篇博客文章',
            content='这是文章内容...',
            author_id=user.id,
            is_published=True
        )

        # 添加分类
        category1 = Category(name='Python', slug='python')
        category2 = Category(name='Flask', slug='flask')
        db.session.add(category1)
        db.session.add(category2)

        post.categories.append(category1)
        post.categories.append(category2)

        db.session.add(post)
        db.session.commit()

        return post

def complex_queries():
    """复杂查询示例"""
    # JOIN查询
    posts_with_authors = db.session.query(Post, User).join(User).all()

    # 聚合查询
    from sqlalchemy import func
    post_count = db.session.query(func.count(Post.id)).scalar()
    user_post_counts = db.session.query(
        User.username,
        func.count(Post.id).label('post_count')
    ).join(Post).group_by(User.id).all()

    # 子查询
    subquery = db.session.query(
        Post.author_id.label('author_id'),
        func.count(Post.id).label('post_count')
    ).group_by(Post.author_id).subquery()

    users_with_post_count = db.session.query(
        User, subquery.c.post_count
    ).outerjoin(subquery, User.id == subquery.c.author_id).all()

    return posts_with_authors

6. 在视图函数中使用数据库

from flask import Flask, render_template, request, redirect, url_for, flash, jsonify
from flask_login import login_user, logout_user, login_required, current_user

@app.route('/register', methods=['GET', 'POST'])
def register():
    if request.method == 'POST':
        username = request.form.get('username')
        email = request.form.get('email')
        password = request.form.get('password')

        # 检查用户名和邮箱是否已存在
        existing_user = User.query.filter(
            (User.username == username) | (User.email == email)
        ).first()

        if existing_user:
            if existing_user.username == username:
                flash('用户名已存在', 'error')
            else:
                flash('邮箱已存在', 'error')
            return redirect(url_for('register'))

        # 创建新用户
        new_user = User(
            username=username,
            email=email,
            password_hash=generate_password_hash(password)
        )

        try:
            db.session.add(new_user)
            db.session.commit()
            flash('注册成功!请登录。', 'success')
            return redirect(url_for('login'))
        except Exception as e:
            db.session.rollback()
            flash(f'注册失败: {str(e)}', 'error')
            return redirect(url_for('register'))

    return render_template('register.html')

@app.route('/posts')
def list_posts():
    """文章列表页面"""
    page = request.args.get('page', 1, type=int)
    per_page = 10

    # 获取查询参数
    category_id = request.args.get('category', type=int)
    search_query = request.args.get('q', '')

    # 构建查询
    query = Post.query.filter_by(is_published=True)

    # 按分类筛选
    if category_id:
        query = query.filter(Post.categories.any(id=category_id))

    # 搜索功能
    if search_query:
        query = query.filter(
            (Post.title.contains(search_query)) |
            (Post.content.contains(search_query))
        )

    # 排序和分页
    posts = query.order_by(Post.created_at.desc()).paginate(
        page=page,
        per_page=per_page,
        error_out=False
    )

    # 获取所有分类(用于筛选器)
    categories = Category.query.all()

    return render_template('posts.html',
                         posts=posts,
                         categories=categories,
                         current_category=category_id,
                         search_query=search_query)

@app.route('/api/posts', methods=['GET'])
def api_get_posts():
    """API:获取文章列表"""
    page = request.args.get('page', 1, type=int)
    per_page = request.args.get('per_page', 10, type=int)

    posts = Post.query.filter_by(is_published=True).paginate(
        page=page,
        per_page=per_page
    )

    return jsonify({
        'success': True,
        'data': [{
            'id': post.id,
            'title': post.title,
            'content': post.content[:200],  # 截取前200个字符
            'created_at': post.created_at.isoformat(),
            'author': post.author.username
        } for post in posts.items],
        'pagination': {
            'page': posts.page,
            'per_page': posts.per_page,
            'total': posts.total,
            'pages': posts.pages
        }
    })

@app.route('/admin/dashboard')
@login_required
def admin_dashboard():
    """管理后台仪表板"""
    # 统计信息
    stats = {
        'total_users': User.query.count(),
        'total_posts': Post.query.count(),
        'total_comments': Comment.query.count(),
        'published_posts': Post.query.filter_by(is_published=True).count(),
        'recent_users': User.query.order_by(User.created_at.desc()).limit(5).all()
    }

    # 最新活动
    recent_posts = Post.query.order_by(Post.created_at.desc()).limit(10).all()
    recent_comments = Comment.query.order_by(Comment.created_at.desc()).limit(10).all()

    return render_template('admin/dashboard.html',
                         stats=stats,
                         recent_posts=recent_posts,
                         recent_comments=recent_comments)

7. 连接池和性能优化

from sqlalchemy.pool import QueuePool
from sqlalchemy import create_engine

# 自定义数据库引擎配置
def create_database_engine():
    """创建数据库引擎(用于高级配置)"""
    database_url = app.config['SQLALCHEMY_DATABASE_URI']

    engine = create_engine(
        database_url,
        poolclass=QueuePool,  # 使用连接池
        pool_size=app.config.get('SQLALCHEMY_POOL_SIZE', 10),
        max_overflow=app.config.get('SQLALCHEMY_MAX_OVERFLOW', 20),
        pool_timeout=app.config.get('SQLALCHEMY_POOL_TIMEOUT', 30),
        pool_recycle=app.config.get('SQLALCHEMY_POOL_RECYCLE', 3600),
        pool_pre_ping=True,  # 连接前检查连接是否有效
        echo=app.config.get('SQLALCHEMY_ECHO', False)  # 是否输出SQL语句
    )

    return engine

# 性能优化建议
class OptimizedQuery:
    """查询优化示例"""

    @staticmethod
    def eager_load_relationships():
        """使用急切加载减少查询次数"""
        # ❌ 不好的做法:N+1查询问题
        posts = Post.query.all()
        for post in posts:
            print(post.author.username)  # 每次循环都会查询数据库

        # ✅ 好的做法:使用join或joinedload
        from sqlalchemy.orm import joinedload

        # 方法1:使用join
        posts = db.session.query(Post).join(User).all()

        # 方法2:使用joinedload
        posts = Post.query.options(joinedload(Post.author)).all()

        return posts

    @staticmethod
    def use_select_only():
        """只选择需要的字段"""
        # ❌ 选择所有字段
        users = User.query.all()

        # ✅ 只选择需要的字段
        users = db.session.query(User.id, User.username, User.email).all()

        return users

    @staticmethod
    def paginate_large_datasets():
        """分页处理大数据集"""
        page = 1
        per_page = 50

        # 使用SQLAlchemy的分页
        pagination = User.query.paginate(page=page, per_page=per_page)

        # 或者使用limit和offset
        users = User.query.offset((page-1)*per_page).limit(per_page).all()

        return pagination

    @staticmethod
    def use_indexes():
        """为常用查询字段创建索引"""
        # 在模型定义中添加索引
        class User(db.Model):
            __tablename__ = 'users'
            __table_args__ = (
                db.Index('idx_username', 'username'),  # 为username创建索引
                db.Index('idx_email', 'email'),        # 为email创建索引
                db.Index('idx_created_at', 'created_at'),  # 为创建时间创建索引
            )

            id = db.Column(db.Integer, primary_key=True)
            username = db.Column(db.String(80), nullable=False)
            # ... 其他字段

8. 多数据库配置

# 多数据库配置示例
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)

# 配置多个数据库
app.config['SQLALCHEMY_BINDS'] = {
    'users': 'sqlite:///users.db',
    'posts': 'sqlite:///posts.db',
    'analytics': 'mysql://root:password@localhost/analytics'
}

app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

db = SQLAlchemy(app)

# 绑定到特定数据库的模型
class User(db.Model):
    """用户模型 - 使用users数据库"""
    __bind_key__ = 'users'  # 指定数据库绑定
    __tablename__ = 'users'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True)
    # ... 其他字段

class Post(db.Model):
    """文章模型 - 使用posts数据库"""
    __bind_key__ = 'posts'  # 指定数据库绑定
    __tablename__ = 'posts'

    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(200))
    # ... 其他字段

class Analytics(db.Model):
    """分析数据模型 - 使用analytics数据库"""
    __bind_key__ = 'analytics'  # 指定数据库绑定
    __tablename__ = 'analytics'

    id = db.Column(db.Integer, primary_key=True)
    event = db.Column(db.String(50))
    timestamp = db.Column(db.DateTime, default=datetime.utcnow)
    # ... 其他字段

# 使用特定数据库连接
def query_multiple_databases():
    """查询多个数据库"""
    # 查询users数据库
    users = db.session.query(User).all()

    # 查询posts数据库
    posts = db.session.query(Post).all()

    # 手动指定数据库连接
    with db.engine.connect(bind_key='analytics') as conn:
        result = conn.execute("SELECT * FROM analytics WHERE event = 'page_view'")
        analytics_data = result.fetchall()

    return users, posts, analytics_data

9. 最佳实践和常见问题

  • 使用环境变量存储敏感信息:不要将数据库密码硬编码在代码中
  • 使用连接池:提高数据库连接性能
  • 启用SQL日志:开发环境中启用SQLALCHEMY_ECHO调试SQL
  • 使用数据库迁移:始终使用Flask-Migrate管理数据库结构变更
  • 处理事务:使用try-except块处理数据库操作异常
  • 避免N+1查询:使用joinedloadselect_related优化关联查询
  • 定期备份:定期备份数据库,特别是生产环境

10. 环境配置示例

import os
from dotenv import load_dotenv

# 加载环境变量
load_dotenv()

class Config:
    """基础配置"""
    SECRET_KEY = os.environ.get('SECRET_KEY') or 'dev-secret-key'
    SQLALCHEMY_TRACK_MODIFICATIONS = False

    # 数据库配置
    DB_HOST = os.environ.get('DB_HOST', 'localhost')
    DB_PORT = os.environ.get('DB_PORT', '3306')
    DB_USER = os.environ.get('DB_USER', 'root')
    DB_PASSWORD = os.environ.get('DB_PASSWORD', '')
    DB_NAME = os.environ.get('DB_NAME', 'flask_app')

class DevelopmentConfig(Config):
    """开发环境配置"""
    DEBUG = True
    SQLALCHEMY_ECHO = True
    SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://{Config.DB_USER}:{Config.DB_PASSWORD}@{Config.DB_HOST}:{Config.DB_PORT}/{Config.DB_NAME}"

class TestingConfig(Config):
    """测试环境配置"""
    TESTING = True
    SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'

class ProductionConfig(Config):
    """生产环境配置"""
    DEBUG = False
    # 生产环境使用更安全的连接配置
    SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://{Config.DB_USER}:{Config.DB_PASSWORD}@{Config.DB_HOST}:{Config.DB_PORT}/{Config.DB_NAME}"
    SQLALCHEMY_ENGINE_OPTIONS = {
        'pool_size': 20,
        'pool_recycle': 3600,
        'pool_pre_ping': True,
        'max_overflow': 30,
        'echo': False
    }
    # 生产环境建议使用SSL连接
    # SQLALCHEMY_DATABASE_URI += '?ssl=true'

# 根据环境选择配置
config = {
    'development': DevelopmentConfig,
    'testing': TestingConfig,
    'production': ProductionConfig,
    'default': DevelopmentConfig
}

def create_app(config_name='default'):
    """应用工厂函数"""
    app = Flask(__name__)

    # 加载配置
    app.config.from_object(config[config_name])

    # 初始化扩展
    db.init_app(app)

    return app