首先定义一些示例模型,用于演示数据库操作:
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from datetime import datetime
app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///test.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
db = SQLAlchemy(app)
class User(db.Model):
"""用户模型"""
__tablename__ = 'users'
id = db.Column(db.Integer, primary_key=True)
username = db.Column(db.String(80), unique=True, nullable=False)
email = db.Column(db.String(120), unique=True, nullable=False)
age = db.Column(db.Integer)
is_active = db.Column(db.Boolean, default=True)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
# 一对多关系:一个用户有多篇文章
posts = db.relationship('Post', backref='author', lazy='dynamic', cascade='all, delete-orphan')
def __repr__(self):
return f'<User {self.username}>'
class Post(db.Model):
"""文章模型"""
__tablename__ = 'posts'
id = db.Column(db.Integer, primary_key=True)
title = db.Column(db.String(200), nullable=False)
content = db.Column(db.Text, nullable=False)
views = db.Column(db.Integer, default=0)
is_published = db.Column(db.Boolean, default=True)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
# 外键:作者ID
author_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
def __repr__(self):
return f'<Post {self.title}>'
class Category(db.Model):
"""分类模型"""
__tablename__ = 'categories'
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(50), unique=True, nullable=False)
description = db.Column(db.Text)
# 多对多关系:文章和分类
posts = db.relationship('Post', secondary='post_categories', backref='categories')
def __repr__(self):
return f'<Category {self.name}>'
# 关联表:文章分类
post_categories = db.Table('post_categories',
db.Column('post_id', db.Integer, db.ForeignKey('posts.id'), primary_key=True),
db.Column('category_id', db.Integer, db.ForeignKey('categories.id'), primary_key=True),
db.Column('created_at', db.Column(db.DateTime, default=datetime.utcnow))
)
# 方法1:使用模型构造函数创建对象
def create_single_record():
"""创建单条记录"""
# 创建用户对象
user = User(
username='john_doe',
email='john@example.com',
age=25,
is_active=True
)
# 添加到会话
db.session.add(user)
# 提交到数据库
try:
db.session.commit()
print(f"用户创建成功: {user.username}")
return user
except Exception as e:
db.session.rollback()
print(f"创建失败: {e}")
return None
# 方法2:直接使用session.add()并提交
def create_user_simple():
"""简化版创建用户"""
user = User(username='jane_doe', email='jane@example.com')
db.session.add(user)
db.session.commit()
return user
# 方法3:使用create()方法(如果模型中有定义)
class User(db.Model):
# ... 其他字段定义
@classmethod
def create(cls, **kwargs):
"""创建用户的便捷方法"""
user = cls(**kwargs)
db.session.add(user)
db.session.commit()
return user
# 使用便捷方法
user = User.create(username='bob', email='bob@example.com')
def create_bulk_records():
"""批量创建记录"""
# 准备用户数据列表
users_data = [
{'username': 'alice', 'email': 'alice@example.com', 'age': 28},
{'username': 'bob', 'email': 'bob@example.com', 'age': 32},
{'username': 'charlie', 'email': 'charlie@example.com', 'age': 24},
{'username': 'david', 'email': 'david@example.com', 'age': 35},
{'username': 'eve', 'email': 'eve@example.com', 'age': 29},
]
# 方法1:使用循环逐个添加
users = []
for data in users_data:
user = User(**data)
db.session.add(user)
users.append(user)
# 方法2:使用add_all()一次性添加
# users = [User(**data) for data in users_data]
# db.session.add_all(users)
try:
db.session.commit()
print(f"批量创建成功,共创建了 {len(users)} 个用户")
return users
except Exception as e:
db.session.rollback()
print(f"批量创建失败: {e}")
return []
def create_with_relationships():
"""创建有关联关系的记录"""
# 创建用户
user = User(username='author1', email='author1@example.com')
db.session.add(user)
db.session.commit() # 需要先提交,获取用户ID
# 创建文章并关联用户
post1 = Post(
title='我的第一篇文章',
content='这是文章内容...',
author_id=user.id
)
post2 = Post(
title='Flask数据库教程',
content='学习Flask数据库操作...',
author_id=user.id
)
# 更简单的方式:直接使用关系
post3 = Post(title='Python入门', content='Python基础教程...')
user.posts.append(post3) # 通过关系添加
# 创建分类
python_category = Category(name='Python', description='Python相关文章')
flask_category = Category(name='Flask', description='Flask框架相关文章')
# 关联文章和分类
post1.categories.append(python_category)
post2.categories.extend([python_category, flask_category])
db.session.add_all([post1, post2, post3, python_category, flask_category])
db.session.commit()
print("创建完成,包含关联关系")
return user
def basic_queries():
"""基本查询操作"""
# 1. 查询所有记录
all_users = User.query.all()
print(f"所有用户: {all_users}")
# 2. 查询第一条记录
first_user = User.query.first()
print(f"第一个用户: {first_user}")
# 3. 通过主键查询
user_by_id = User.query.get(1) # 查询ID为1的用户
print(f"ID为1的用户: {user_by_id}")
# 4. 查询记录数量
user_count = User.query.count()
print(f"用户总数: {user_count}")
# 5. 检查是否存在
exists = User.query.filter_by(username='john_doe').first() is not None
print(f"用户john_doe是否存在: {exists}")
return all_users
def filter_queries():
"""条件过滤查询"""
# 1. 等值过滤
active_users = User.query.filter_by(is_active=True).all()
print(f"活跃用户: {active_users}")
# 2. 比较过滤
adult_users = User.query.filter(User.age >= 18).all()
print(f"成年用户: {adult_users}")
# 3. 多条件过滤
active_adults = User.query.filter(
User.is_active == True,
User.age >= 18
).all()
print(f"活跃的成年用户: {active_adults}")
# 4. LIKE模糊查询
users_like_john = User.query.filter(User.username.like('%john%')).all()
print(f"用户名包含john的用户: {users_like_john}")
# 5. IN查询
specific_users = User.query.filter(User.username.in_(['john_doe', 'jane_doe'])).all()
print(f"特定用户名用户: {specific_users}")
# 6. 范围查询
users_in_range = User.query.filter(User.age.between(20, 30)).all()
print(f"年龄在20-30之间的用户: {users_in_range}")
# 7. NULL值查询
users_without_age = User.query.filter(User.age == None).all()
print(f"年龄为空的用户: {users_without_age}")
return active_users
def sorting_and_limiting():
"""排序和限制查询"""
# 1. 升序排序
users_asc = User.query.order_by(User.username).all()
print(f"按用户名升序: {users_asc}")
# 2. 降序排序
users_desc = User.query.order_by(User.username.desc()).all()
print(f"按用户名降序: {users_desc}")
# 3. 多列排序
users_multi_sort = User.query.order_by(User.is_active.desc(), User.username).all()
print(f"先按活跃状态降序,再按用户名升序: {users_multi_sort}")
# 4. 限制结果数量
recent_users = User.query.order_by(User.created_at.desc()).limit(5).all()
print(f"最近创建的5个用户: {recent_users}")
# 5. 偏移查询(分页)
page_2_users = User.query.order_by(User.id).offset(10).limit(10).all()
print(f"第2页用户(每页10条): {page_2_users}")
# 6. 组合:最近活跃的成年用户
recent_active_adults = User.query.filter(
User.is_active == True,
User.age >= 18
).order_by(
User.last_login.desc()
).limit(10).all()
return recent_users
def relationship_queries():
"""关联关系查询"""
# 1. 查询用户的所有文章
user = User.query.get(1)
if user:
user_posts = user.posts.all() # 使用关系属性
print(f"用户 {user.username} 的文章: {user_posts}")
# 也可以使用查询
posts_by_user = Post.query.filter_by(author_id=user.id).all()
# 2. 查询文章的作者
post = Post.query.first()
if post:
author = post.author # 通过backref访问
print(f"文章 {post.title} 的作者: {author.username}")
# 3. 带条件的关联查询
# 查询所有发表过文章的用户
users_with_posts = User.query.join(Post).filter(Post.is_published == True).all()
print(f"发表过文章的用户: {users_with_posts}")
# 4. 查询文章数量大于N的用户
from sqlalchemy import func
active_authors = db.session.query(
User, func.count(Post.id).label('post_count')
).join(
Post, User.id == Post.author_id
).filter(
Post.is_published == True
).group_by(
User.id
).having(
func.count(Post.id) > 0
).all()
print(f"活跃作者及其文章数量: {active_authors}")
# 5. 多对多关联查询
python_posts = Post.query.join(
post_categories
).join(
Category
).filter(
Category.name == 'Python'
).all()
print(f"Python分类下的文章: {python_posts}")
return users_with_posts
def aggregate_queries():
"""聚合查询"""
from sqlalchemy import func
# 1. 计数
total_users = db.session.query(func.count(User.id)).scalar()
print(f"用户总数: {total_users}")
# 2. 求和
total_age = db.session.query(func.sum(User.age)).scalar()
print(f"用户年龄总和: {total_age}")
# 3. 平均值
avg_age = db.session.query(func.avg(User.age)).scalar()
print(f"平均年龄: {avg_age}")
# 4. 最大值和最小值
max_age = db.session.query(func.max(User.age)).scalar()
min_age = db.session.query(func.min(User.age)).scalar()
print(f"最大年龄: {max_age}, 最小年龄: {min_age}")
# 5. 分组统计
age_groups = db.session.query(
User.age,
func.count(User.id).label('count')
).filter(
User.age != None
).group_by(
User.age
).order_by(
User.age
).all()
print("按年龄分组统计:")
for age, count in age_groups:
print(f" 年龄 {age}: {count} 人")
# 6. 复杂聚合:每个用户的文章数量和总浏览量
user_stats = db.session.query(
User.username,
func.count(Post.id).label('post_count'),
func.sum(Post.views).label('total_views')
).outerjoin(
Post, User.id == Post.author_id
).group_by(
User.id
).order_by(
func.count(Post.id).desc()
).all()
print("用户文章统计:")
for username, post_count, total_views in user_stats:
print(f" {username}: {post_count} 篇文章, {total_views or 0} 次浏览")
return age_groups
def update_operations():
"""更新操作"""
# 1. 更新单个字段
user = User.query.get(1)
if user:
user.username = 'new_username'
db.session.commit()
print(f"用户名已更新: {user.username}")
# 2. 更新多个字段
user = User.query.filter_by(email='john@example.com').first()
if user:
user.age = 30
user.is_active = False
user.updated_at = datetime.utcnow()
db.session.commit()
print(f"用户信息已更新")
# 3. 批量更新
# 将所有不活跃用户的年龄设为0
updated_count = User.query.filter_by(is_active=False).update(
{'age': 0},
synchronize_session=False
)
db.session.commit()
print(f"批量更新了 {updated_count} 个用户")
# 4. 递增/递减操作
post = Post.query.first()
if post:
post.views = Post.views + 1 # 或使用 increment_views 方法
db.session.commit()
print(f"文章浏览量已增加")
# 5. 使用自定义方法更新
class User(db.Model):
# ... 其他字段
def update_profile(self, **kwargs):
"""更新用户资料"""
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
self.updated_at = datetime.utcnow()
db.session.commit()
# 使用自定义方法
user = User.query.get(1)
user.update_profile(username='updated_name', age=35)
# 6. 更新关联数据
user = User.query.get(1)
if user and user.posts:
# 更新用户的所有文章
for post in user.posts:
post.views += 10
db.session.commit()
print(f"更新了用户的所有文章浏览量")
return user
class Product(db.Model):
"""使用版本控制的商品模型"""
__tablename__ = 'products'
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(100), nullable=False)
stock = db.Column(db.Integer, default=0)
price = db.Column(db.Numeric(10, 2), default=0.00)
version = db.Column(db.Integer, default=0) # 版本号
def decrease_stock(self, quantity):
"""减少库存(乐观锁)"""
from sqlalchemy import update
if self.stock < quantity:
raise ValueError("库存不足")
# 使用版本控制防止并发问题
stmt = update(Product).where(
Product.id == self.id,
Product.version == self.version
).values(
stock=Product.stock - quantity,
version=Product.version + 1
)
result = db.session.execute(stmt)
db.session.commit()
if result.rowcount == 0:
# 版本不匹配,说明数据已被修改
db.session.rollback()
raise ValueError("更新冲突,请重试")
# 刷新本地对象
db.session.refresh(self)
print(f"库存已减少 {quantity},当前库存: {self.stock}")
def increase_price(self, percentage):
"""增加价格(乐观锁)"""
from sqlalchemy import update
new_price = self.price * (1 + percentage / 100)
stmt = update(Product).where(
Product.id == self.id,
Product.version == self.version
).values(
price=new_price,
version=Product.version + 1
)
result = db.session.execute(stmt)
db.session.commit()
if result.rowcount == 0:
db.session.rollback()
raise ValueError("更新冲突,请重试")
db.session.refresh(self)
def delete_operations():
"""删除操作"""
# 1. 删除单条记录
user = User.query.get(10) # 假设删除ID为10的用户
if user:
db.session.delete(user)
db.session.commit()
print(f"用户 {user.username} 已删除")
# 2. 条件删除
deleted_count = User.query.filter_by(is_active=False).delete()
db.session.commit()
print(f"删除了 {deleted_count} 个不活跃用户")
# 3. 删除所有记录(谨慎使用!)
# all_deleted = User.query.delete()
# db.session.commit()
# 4. 级联删除(通过关系)
user = User.query.get(1)
if user:
# 删除用户及其所有文章(因为设置了cascade)
db.session.delete(user)
db.session.commit()
print(f"用户及其关联数据已删除")
# 5. 软删除(标记删除而非物理删除)
class SoftDeleteMixin:
"""软删除混入类"""
is_deleted = db.Column(db.Boolean, default=False)
deleted_at = db.Column(db.DateTime)
def soft_delete(self):
"""软删除"""
self.is_deleted = True
self.deleted_at = datetime.utcnow()
db.session.commit()
@classmethod
def query_active(cls):
"""查询未删除的记录"""
return cls.query.filter_by(is_deleted=False)
# 6. 删除关联表中的记录
# 删除文章的所有分类关联
post = Post.query.get(1)
if post:
post.categories = [] # 清空分类关联
db.session.commit()
print(f"文章分类关联已清空")
# 或者直接从关联表删除
from sqlalchemy import delete as sql_delete
stmt = sql_delete(post_categories).where(
post_categories.c.post_id == 1
)
db.session.execute(stmt)
db.session.commit()
return deleted_count
def transaction_examples():
"""事务管理示例"""
# 1. 基本事务
try:
user1 = User(username='user1', email='user1@example.com')
user2 = User(username='user2', email='user2@example.com')
db.session.add(user1)
db.session.add(user2)
# 模拟一个可能失败的操作
if user1.username == user2.username:
raise ValueError("用户名重复")
db.session.commit()
print("事务提交成功")
except Exception as e:
db.session.rollback()
print(f"事务回滚: {e}")
# 2. 嵌套事务
try:
# 开始事务
user = User(username='nested_user', email='nested@example.com')
db.session.add(user)
# 嵌套操作
try:
post = Post(title='Nested Post', content='...', author_id=user.id)
db.session.add(post)
db.session.commit()
print("嵌套事务提交成功")
except Exception as e:
db.session.rollback()
print(f"嵌套事务回滚: {e}")
raise # 重新抛出异常
except Exception as e:
db.session.rollback()
print(f"外层事务回滚: {e}")
# 3. 使用上下文管理器
from contextlib import contextmanager
@contextmanager
def transaction():
"""事务上下文管理器"""
try:
yield
db.session.commit()
except Exception as e:
db.session.rollback()
raise e
# 使用事务上下文管理器
with transaction():
user = User(username='ctx_user', email='ctx@example.com')
post = Post(title='Ctx Post', content='...', author_id=user.id)
db.session.add_all([user, post])
print("上下文管理器事务完成")
# 4. 保存点(复杂事务)
try:
user1 = User(username='savepoint1', email='sp1@example.com')
db.session.add(user1)
# 创建保存点
savepoint = db.session.begin_nested()
try:
user2 = User(username='savepoint2', email='sp2@example.com')
db.session.add(user2)
# 模拟失败
if True: # 测试用,总是触发回滚到保存点
raise ValueError("保存点内的操作失败")
savepoint.commit()
except Exception as e:
savepoint.rollback()
print(f"保存点回滚: {e}")
# 外层事务继续
db.session.commit()
print("主事务提交成功")
except Exception as e:
db.session.rollback()
print(f"主事务回滚: {e}")
def pagination_examples():
"""分页查询示例"""
# 1. 使用SQLAlchemy的分页器
def get_users_page(page=1, per_page=10):
"""获取用户分页"""
pagination = User.query.filter_by(is_active=True).order_by(
User.created_at.desc()
).paginate(
page=page,
per_page=per_page,
error_out=False # 如果page超出范围不抛出异常
)
return {
'users': pagination.items,
'page': pagination.page,
'per_page': pagination.per_page,
'total': pagination.total,
'pages': pagination.pages,
'has_prev': pagination.has_prev,
'has_next': pagination.has_next,
'prev_num': pagination.prev_num,
'next_num': pagination.next_num
}
# 2. 手动实现分页
def manual_pagination(page=1, per_page=10):
"""手动分页"""
offset = (page - 1) * per_page
users = User.query.order_by(User.id).offset(offset).limit(per_page).all()
total = User.query.count()
return {
'users': users,
'page': page,
'per_page': per_page,
'total': total,
'pages': (total + per_page - 1) // per_page
}
# 3. 在视图函数中使用分页
from flask import request, render_template
@app.route('/users')
def list_users():
page = request.args.get('page', 1, type=int)
per_page = request.args.get('per_page', 10, type=int)
pagination = User.query.paginate(page=page, per_page=per_page)
return render_template('users/list.html', pagination=pagination)
# 4. 复杂查询的分页
def get_posts_by_category(category_id, page=1, per_page=10):
"""按分类获取文章分页"""
query = Post.query.join(
post_categories
).filter(
post_categories.c.category_id == category_id,
Post.is_published == True
).order_by(
Post.created_at.desc()
)
pagination = query.paginate(page=page, per_page=per_page)
return pagination
# 5. 分页模板示例
'''
<!-- templates/users/list.html -->
@% for user in pagination.items %>
<div>{{ user.username }}</div>
@% endfor %>
<!-- 分页导航 -->
<nav>
<ul class="pagination">
@% if pagination.has_prev %>
<li class="page-item">
<a class="page-link" href="?page={{ pagination.prev_num }}">上一页</a>
</li>
@% endif %>
@% for page_num in pagination.iter_pages() %>
@% if page_num %>
<li class="page-item {{ 'active' if page_num == pagination.page else '' }}">
<a class="page-link" href="?page={{ page_num }}">{{ page_num }}</a>
</li>
@% else %>
<li class="page-item disabled">
<span class="page-link">...</span>
</li>
@% endif %>
@% endfor %>
@% if pagination.has_next %>
<li class="page-item">
<a class="page-link" href="?page={{ pagination.next_num }}">下一页</a>
</li>
@% endif %>
</ul>
</nav>
'''
return get_users_page(1, 10)
def advanced_query_techniques():
"""高级查询技巧"""
from sqlalchemy import and_, or_, not_
from sqlalchemy.orm import aliased
# 1. 使用and_, or_, not_组合条件
complex_query = User.query.filter(
or_(
and_(User.age >= 18, User.age <= 30),
User.is_active == True
),
not_(User.username.like('%admin%'))
).all()
# 2. 子查询
from sqlalchemy import select
# 创建子查询:文章数量大于平均值的用户
subquery = select([User.id]).join(Post).group_by(User.id).having(
func.count(Post.id) > select([func.avg(func.count(Post.id))]).select_from(
User
).join(Post).group_by(User.id).as_scalar()
).alias()
active_authors = User.query.filter(User.id.in_(subquery)).all()
# 3. 使用别名
UserAlias = aliased(User)
query = db.session.query(User, UserAlias).filter(
User.id != UserAlias.id,
User.email == UserAlias.email
).all()
# 4. 窗口函数(需要数据库支持)
from sqlalchemy import over
# 给用户按年龄排名
ranked_users = db.session.query(
User,
func.row_number().over(
order_by=User.age.desc(),
partition_by=User.is_active
).label('rank')
).all()
# 5. 递归查询(用于树状结构)
class Category(db.Model):
# ... 其他字段
parent_id = db.Column(db.Integer, db.ForeignKey('categories.id'))
parent = db.relationship('Category', remote_side=[id], backref='children')
# 查询某个分类的所有子分类
def get_all_subcategories(category_id):
from sqlalchemy.orm import aliased
CategoryAlias = aliased(Category)
# 使用递归CTE(需要数据库支持)
# 这里简化处理:多次查询
all_categories = []
current_level = [category_id]
while current_level:
next_level = Category.query.filter(
Category.parent_id.in_(current_level)
).all()
all_categories.extend(next_level)
current_level = [cat.id for cat in next_level]
return all_categories
# 6. 全文搜索(需要数据库支持全文索引)
def fulltext_search(search_term):
"""全文搜索示例"""
# 使用LIKE模拟
results = Post.query.filter(
or_(
Post.title.ilike(f'%{search_term}%'),
Post.content.ilike(f'%{search_term}%')
)
).all()
# 如果使用PostgreSQL,可以使用更高级的全文搜索
# from sqlalchemy import text
# results = Post.query.filter(
# text("to_tsvector('english', title || ' ' || content) @@ plainto_tsquery(:query)")
# ).params(query=search_term).all()
return results
return complex_query
def performance_optimization():
"""性能优化技巧"""
from sqlalchemy.orm import joinedload, selectinload, subqueryload
# 1. 解决N+1查询问题
# ❌ 不好的做法:每次循环都会查询数据库
users = User.query.all()
for user in users:
posts = user.posts.all() # N+1查询!
# ✅ 好的做法:使用joinedload急切加载
users_with_posts = User.query.options(joinedload(User.posts)).all()
for user in users_with_posts:
posts = user.posts # 已加载,不会再次查询
# 2. 使用selectinload(适合一对多关系)
users = User.query.options(selectinload(User.posts)).all()
# 3. 只选择需要的字段
# ❌ 选择所有字段
users = User.query.all()
# ✅ 只选择需要的字段
users_light = db.session.query(User.id, User.username, User.email).all()
# 4. 使用索引优化查询
# 确保常用查询字段有索引
# 在模型定义中添加:index=True
# 5. 批量操作代替循环
# ❌ 循环更新
users = User.query.all()
for user in users:
user.login_count += 1
db.session.commit()
# ✅ 批量更新
User.query.update({'login_count': User.login_count + 1})
db.session.commit()
# 6. 使用数据库的延迟约束
# 在大量插入时,可以暂时禁用外键检查
if app.config['SQLALCHEMY_DATABASE_URI'].startswith('sqlite'):
db.session.execute('PRAGMA foreign_keys = OFF')
# 执行批量插入
db.session.execute('PRAGMA foreign_keys = ON')
# 7. 分页而不是一次性加载所有数据
# 总是使用分页处理大量数据
# 8. 使用缓存
from flask_caching import Cache
cache = Cache(app, config={'CACHE_TYPE': 'simple'})
@cache.memoize(timeout=300) # 缓存5分钟
def get_active_users_count():
return User.query.filter_by(is_active=True).count()
return users_with_posts
class DatabaseUtils:
"""数据库工具类"""
@staticmethod
def get_or_create(model, **kwargs):
"""获取或创建记录"""
instance = model.query.filter_by(**kwargs).first()
if instance:
return instance, False # False表示已存在
instance = model(**kwargs)
db.session.add(instance)
db.session.commit()
return instance, True # True表示新创建的
@staticmethod
def update_or_create(model, defaults=None, **kwargs):
"""更新或创建记录"""
instance = model.query.filter_by(**kwargs).first()
if instance:
# 更新现有记录
for key, value in (defaults or {}).items():
setattr(instance, key, value)
created = False
else:
# 创建新记录
params = kwargs.copy()
params.update(defaults or {})
instance = model(**params)
created = True
db.session.add(instance)
db.session.commit()
return instance, created
@staticmethod
def bulk_upsert(model, data_list, conflict_fields):
"""批量插入或更新"""
# 简化实现,实际应用中可能需要更复杂的逻辑
for data in data_list:
filter_args = {field: data[field] for field in conflict_fields}
instance = model.query.filter_by(**filter_args).first()
if instance:
for key, value in data.items():
setattr(instance, key, value)
else:
instance = model(**data)
db.session.add(instance)
db.session.commit()
@staticmethod
def export_to_csv(model, filepath):
"""导出数据到CSV"""
import csv
records = model.query.all()
if not records:
return False
# 获取字段名
fields = [column.name for column in model.__table__.columns]
with open(filepath, 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
writer.writerow(fields)
for record in records:
row = [getattr(record, field) for field in fields]
writer.writerow(row)
return True
@staticmethod
def import_from_csv(model, filepath):
"""从CSV导入数据"""
import csv
with open(filepath, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
instance = model(**row)
db.session.add(instance)
db.session.commit()
return True
# 使用示例
user, created = DatabaseUtils.get_or_create(
User,
username='test_user',
defaults={'email': 'test@example.com', 'age': 25}
)