大家好,我是ICodeWR。今天要记录的是Streamlit高级功能。
学习目标
掌握性能优化、主题定制、自定义组件开发等高级功能
实现数据库集成、实时通信、安全认证等扩展功能
构建企业级数据管理平台并完成生产级部署
项目效果图
1. 核心高级功能
1.1 缓存优化
@st.cache_data(ttl=3600, show_spinner=True)
def process_big_data(path):
# 大数据处理逻辑
return result
@st.cache_resource
def get_db_connection():
return sqlite3.connect('data.db')
功能
cache_data:
缓存数据处理结果(适合DataFrame等可变数据)
cache_resource:
缓存不可变资源(数据库连接、模型等)
参数
ttl:
缓存有效期(秒)
show_spinner:
显示加载指示器
1.2 主题定制
创建.streamlit/config.toml:
[theme]
base = "dark"
primaryColor = "#FF4B4B"
backgroundColor = "#0E1117"
secondaryBackgroundColor = "#262730"
textColor = "#FAFAFA"
font = "sans serif"
[server]
port = 8502
1.3 自定义组件开发
// 前端组件代码(React)
importReactfrom"react"
import { Streamlit} from"streamlit-component-lib"
classMyComponentextendsReact.Component {
render() {
return<button onClick={() => Streamlit.setComponentValue("clicked")}>
Custom Button
</button>
}
}
2. 综合案例:企业数据管理平台
系统架构
企业数据平台/
├── main.py # 主程序
├── pages/
│ ├── 1__仪表盘.py # 数据可视化
│ ├── 2__数据管理.py # 数据库操作
│ └── 3__系统管理.py # 配置管理
├── components/
│ └── tag_input.py # 自定义标签输入组件
├── utils/
│ ├── database.py # 数据库模块
│ ├── user.py # 用户验证模块
│ └── security.py # 密码认证模块
└── .streamlit/
└── config.toml # 主题配置
核心代码实现
main.py(主入口)
import streamlit as st
from utils.user import authenticate
# 初始化会话状态
if"auth"notin st.session_state:
st.session_state.auth = {
"logged_in": False,
"user": None,
"role": "guest"
}
# 登录验证
ifnot st.session_state.auth["logged_in"]:
authenticate()
st.stop()
# 主界面配置
st.set_page_config(
page_title="企业数据平台",
page_icon="",
layout="wide",
initial_sidebar_state="expanded"
)
# 显示登录信息
st.sidebar.success(f"欢迎, {st.session_state.auth['user']} ({st.session_state.auth['role']})")
# 导航菜单
pages = {
" 数据仪表盘": "pages/1__仪表盘.py",
" 数据管理": "pages/2__数据管理.py",
" 系统管理": "pages/3__系统管理.py"
}
# 只有管理员能看到系统管理
if st.session_state.auth["role"] != "admin":
del pages[" 系统管理"]
selected = st.sidebar.selectbox("导航", list(pages.keys()))
st.switch_page(pages[selected])
components/tag_input.py(自定义组件)
import streamlit as st
import streamlit.components.v1 as components
# 声明自定义组件
deftag_input(label, default=None, key=None):
if default isNone:
default = []
# 组件HTML/JS
component_html = f"""
<div id="tag-input-container">
<style>
.tag-container {{
display: flex;
flex-wrap: wrap;
gap: 5px;
padding: 5px;
border: 1px solid #ccc;
border-radius: 4px;
}}
.tag {{
background-color: #f0f2f6;
padding: 2px 8px;
border-radius: 10px;
display: flex;
align-items: center;
}}
.tag-remove {{
margin-left: 5px;
cursor: pointer;
}}
#tag-input {{
border: none;
outline: none;
flex-grow: 1;
padding: 5px;
}}
</style>
<label>{label}</label>
<div class="tag-container" id="tags-{key}">
<input type="text" id="tag-input-{key}" placeholder="输入标签后按回车...">
</div>
<script>
const container = document.getElementById('tags-{key}');
const input = document.getElementById('tag-input-{key}');
const tags = {default};
function updateTags() {{
Streamlit.setComponentValue(tags);
}}
function createTag(label) {{
const tagDiv = document.createElement('div');
tagDiv.className = 'tag';
const span = document.createElement('span');
span.textContent = label;
const removeBtn = document.createElement('span');
removeBtn.className = 'tag-remove';
removeBtn.innerHTML = '×';
removeBtn.onclick = function() {{
const index = tags.indexOf(label);
if (index !== -1) {{
tags.splice(index, 1);
container.removeChild(tagDiv);
updateTags();
}}
}};
tagDiv.appendChild(span);
tagDiv.appendChild(removeBtn);
return tagDiv;
}}
// 初始化已有标签
tags.forEach(tag => {{
container.insertBefore(createTag(tag), input);
}});
input.addEventListener('keydown', function(e) {{
if (e.key === 'Enter' && input.value.trim() !== '') {{
const tag = input.value.trim();
if (!tags.includes(tag)) {{
tags.push(tag);
container.insertBefore(createTag(tag), input);
updateTags();
}}
input.value = '';
}}
}});
</script>
</div>
"""
# 渲染组件
return components.html(component_html, height=100, key=key)
utils/database.py(数据库模块)
import sqlite3
import pandas as pd
import streamlit as st
from functools import wraps
from utils.security import hash_password
defhandle_db_errors(func):
@wraps(func)
defwrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except sqlite3.Error as e:
st.error(f"数据库错误: {str(e)}")
raise
except Exception as e:
st.error(f"操作失败: {str(e)}")
raise
return wrapper
@st.cache_resource(show_spinner=False)
defget_connection():
"""获取数据库连接"""
conn = sqlite3.connect("data.db", check_same_thread=False)
conn.row_factory = sqlite3.Row
return conn
@handle_db_errors
defquery_db(query, params=None, _conn=None):
"""执行查询并返回DataFrame"""
conn = _conn or get_connection()
return pd.read_sql(query, conn, params=params)
@handle_db_errors
defexecute_db(query, params=None, _conn=None):
"""执行非查询SQL语句"""
conn = _conn or get_connection()
cursor = conn.cursor()
cursor.execute(query, params or ())
conn.commit()
return cursor.rowcount
definit_database():
"""初始化数据库表结构"""
conn = get_connection()
# 创建用户表
execute_db("""
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
password TEXT NOT NULL,
role TEXT NOT NULL DEFAULT 'viewer',
status TEXT NOT NULL DEFAULT 'active',
region TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""", _conn=conn)
# 创建销售表
execute_db("""
CREATE TABLE IF NOT EXISTS sales (
id INTEGER PRIMARY KEY AUTOINCREMENT,
product_id TEXT NOT NULL,
quantity INTEGER NOT NULL,
amount REAL NOT NULL,
date DATE NOT NULL,
region TEXT NOT NULL
)
""", _conn=conn)
# 添加示例数据
ifnot query_db("SELECT COUNT(*) as count FROM users", _conn=conn).iloc[0]['count']:
execute_db("""
INSERT INTO users (username, password, role, region)
VALUES (?, ?, ?, ?)
""", ("admin", hash_password("admin123"), "admin", "华东"), _conn=conn)
# 添加示例销售数据
import random
from datetime import datetime, timedelta
regions = ["华东", "华北", "华南", "华中", "西部"]
products = ["P1001", "P1002", "P1003", "P1004", "P1005"]
for i inrange(100):
date = datetime.now() - timedelta(days=random.randint(0, 30))
execute_db("""
INSERT INTO sales (product_id, quantity, amount, date, region)
VALUES (?, ?, ?, ?, ?)
""", (
random.choice(products),
random.randint(1, 10),
round(random.uniform(100, 1000), 2),
date.strftime("%Y-%m-%d"),
random.choice(regions)
), _conn=conn)
conn.commit()
# 初始化数据库
init_database()
utils/security.py(密码验证模块)
import bcrypt
defhash_password(password):
"""哈希密码"""
salt = bcrypt.gensalt()
return bcrypt.hashpw(password.encode(), salt).decode()
defverify_password(input_pw, hashed_pw):
"""验证密码"""
return bcrypt.checkpw(input_pw.encode(), hashed_pw.encode())
utils/user.py(用户验证模块)
import streamlit as st
from utils.database import get_connection
from utils.security import hash_password, verify_password
defauthenticate():
"""用户认证"""
with st.container():
st.title("企业数据平台登录")
cols = st.columns([1, 3, 1])
with cols[1]:
with st.form("login_form"):
username = st.text_input("用户名")
password = st.text_input("密码", type="password")
remember = st.checkbox("记住我")
if st.form_submit_button("登录"):
user = verify_credentials(username, password)
if user:
st.session_state.auth = {
"logged_in": True,
"user": user["username"],
"role": user["role"]
}
st.rerun()
else:
st.error("用户名或密码错误")
defverify_credentials(username, password):
"""验证用户凭证"""
conn = get_connection()
cursor = conn.cursor()
cursor.execute("""
SELECT username, password, role FROM users
WHERE username = ? AND status = 'active'
""", (username,))
user = cursor.fetchone()
if user and verify_password(password, user["password"]):
returndict(user)
returnNone
defregister_user(username, password, role="viewer"):
"""注册新用户"""
conn = get_connection()
cursor = conn.cursor()
# 检查用户名是否存在
cursor.execute("SELECT 1 FROM users WHERE username = ?", (username,))
if cursor.fetchone():
raise ValueError("用户名已存在")
# 创建用户
hashed_pw = hash_password(password)
cursor.execute("""
INSERT INTO users (username, password, role)
VALUES (?, ?, ?)
""", (username, hashed_pw, role))
conn.commit()
defget_user_role(username):
"""获取用户角色"""
conn = get_connection()
cursor = conn.cursor()
cursor.execute("SELECT role FROM users WHERE username = ?", (username,))
result = cursor.fetchone()
return result["role"] if result else"guest"
pages/1__仪表盘.py(数据可视化)
import streamlit as st
import pandas as pd
import plotly.express as px
from utils.database import query_db
st.title(" 数据仪表盘")
# 获取数据
@st.cache_data(ttl=300)
defget_dashboard_data():
sales_data = query_db("SELECT * FROM sales WHERE date >= date('now', '-30 days')")
user_data = query_db("SELECT * FROM users")
return sales_data, user_data
sales_df, users_df = get_dashboard_data()
# 指标卡片
col1, col2, col3 = st.columns(3)
col1.metric("总销售额", f"¥{sales_df['amount'].sum():,.2f}", "7.2%")
col2.metric("活跃用户", len(users_df[users_df['status'] == 'active']), "-3.1%")
col3.metric("平均订单", f"¥{sales_df['amount'].mean():,.2f}", "1.8%")
# 图表展示
tab1, tab2, tab3 = st.tabs(["销售趋势", "用户分布", "产品分析"])
with tab1:
fig = px.line(sales_df.groupby('date').sum().reset_index(),
x='date', y='amount',
title="30天销售趋势")
st.plotly_chart(fig, use_container_width=True)
with tab2:
col1, col2 = st.columns(2)
with col1:
st.subheader("用户地域分布")
region_count = users_df['region'].value_counts().reset_index()
st.bar_chart(region_count, x='region', y='count')
with col2:
st.subheader("用户状态")
status_count = users_df['status'].value_counts().reset_index()
fig = px.pie(status_count,
values='count',
names='status',
title='用户状态分布')
st.plotly_chart(fig, use_container_width=True)
with tab3:
product_sales = sales_df.groupby('product_id').agg({
'amount': 'sum',
'quantity': 'sum'
}).reset_index()
fig = px.scatter(product_sales, x='quantity', y='amount', size='amount',
hover_name='product_id', title="产品销量与销售额关系")
st.plotly_chart(fig, use_container_width=True)
pages/2__数据管理.py(数据库操作)
import streamlit as st
import pandas as pd
from utils.database import query_db, get_connection
st.title(" 数据管理")
# 标签输入组件
from components.tag_input import tag_input
# 数据表选择
tables = query_db("SELECT name FROM sqlite_master WHERE type='table'")
selected_table = st.selectbox("选择数据表", tables['name'].tolist())
# 显示表数据
if selected_table:
df = query_db(f"SELECT * FROM {selected_table} LIMIT 1000")
st.dataframe(df, use_container_width=True)
# 数据操作选项卡
tab1, tab2, tab3 = st.tabs(["查询", "编辑", "导入"])
with tab1:
st.subheader("自定义查询")
query = st.text_area("输入SQL查询语句", f"SELECT * FROM {selected_table} LIMIT 100")
if st.button("执行查询"):
try:
result = query_db(query)
st.dataframe(result, use_container_width=True)
except Exception as e:
st.error(f"查询错误: {str(e)}")
with tab2:
if st.session_state.auth["role"] notin ["admin", "editor"]:
st.warning("您没有编辑权限")
else:
st.subheader("数据编辑")
selected_columns = st.multiselect("选择显示的列", df.columns.tolist(), default=df.columns.tolist())
edited_df = st.data_editor(df[selected_columns], use_container_width=True)
if st.button("保存更改"):
conn = get_connection()
try:
edited_df.to_sql(selected_table, conn, if_exists="replace", index=False)
st.success("数据保存成功!")
except Exception as e:
st.error(f"保存失败: {str(e)}")
with tab3:
st.subheader("数据导入")
upload_file = st.file_uploader("上传CSV文件", type=["csv"])
if upload_file:
new_data = pd.read_csv(upload_file)
st.write("预览数据:")
st.dataframe(new_data.head())
if st.button("导入数据"):
conn = get_connection()
try:
new_data.to_sql(selected_table, conn, if_exists="append", index=False)
st.success(f"成功导入 {len(new_data)} 条记录!")
except Exception as e:
st.error(f"导入失败: {str(e)}")
pages/3__系统管理.py(配置管理)
import streamlit as st
import os
from utils.database import get_connection, query_db
from utils.user import register_user
st.title(" 系统管理")
if st.session_state.auth["role"] != "admin":
st.error("您没有访问此页面的权限")
st.stop()
tab1, tab2, tab3 = st.tabs(["用户管理", "系统配置", "数据库维护"])
with tab1:
st.subheader("用户账户管理")
# 显示现有用户
users = st.cache_data(ttl=60)(query_db)("SELECT username, role, created_at FROM users")
st.dataframe(users, use_container_width=True)
# 添加新用户
with st.expander("添加新用户"):
with st.form("user_form"):
username = st.text_input("用户名")
password = st.text_input("密码", type="password")
role = st.selectbox("角色", ["admin", "editor", "viewer"])
if st.form_submit_button("创建用户"):
try:
register_user(username, password, role)
st.success(f"用户 {username} 创建成功!")
st.cache_data.clear()
except Exception as e:
st.error(f"创建失败: {str(e)}")
with tab2:
st.subheader("系统配置")
# 主题配置
st.write("### 主题设置")
current_theme = st.selectbox("选择主题", ["light", "dark"], index=1)
# 性能配置
st.write("### 性能设置")
cache_ttl = st.slider("缓存时间(秒)", 60, 3600, 300)
if st.button("保存配置"):
# 这里应该写入配置文件
st.success("配置已保存 (示例: 实际需要写入配置文件)")
with tab3:
st.subheader("数据库维护")
col1, col2 = st.columns(2)
with col1:
st.write("### 备份数据库")
if st.button("创建备份"):
conn = get_connection()
try:
withopen('data_backup.db', 'wb') as f:
for line in conn.iterdump():
f.write(f'{line}\n'.encode('utf-8'))
st.success("备份创建成功!")
except Exception as e:
st.error(f"备份失败: {str(e)}")
with col2:
st.write("### 恢复数据库")
backup_file = st.file_uploader("上传备份文件", type=["db"])
if backup_file and st.button("恢复数据库"):
try:
withopen('data.db', 'wb') as f:
f.write(backup_file.getvalue())
st.success("数据库恢复成功!")
st.cache_data.clear()
st.cache_resource.clear()
except Exception as e:
st.error(f"恢复失败: {str(e)}")
st.write("### 数据库状态")
db_size = os.path.getsize('data.db') / (1024 * 1024)
st.metric("数据库大小", f"{db_size:.2f} MB")
if st.button("优化数据库"):
conn = get_connection()
try:
conn.execute("VACUUM")
st.success("数据库优化完成!")
except Exception as e:
st.error(f"优化失败: {str(e)}")
3. 高级功能实现
3.2 安全认证系统
# utils/security.py
import bcrypt
import hashlib
defhash_password(password):
salt = bcrypt.gensalt()
return bcrypt.hashpw(password.encode(), salt)
defverify_password(input_pw, hashed_pw):
return bcrypt.checkpw(input_pw.encode(), hashed_pw)
defauthenticate():
with st.container():
cols = st.columns([1,3,1])
with cols[1]:
with st.form("登录"):
user = st.text_input("用户名")
pw = st.text_input("密码", type="password")
if st.form_submit_button("登录"):
if verify_credentials(user, pw):
st.session_state.auth = {
"logged_in": True,
"user": user,
"role": get_user_role(user)
}
st.rerun()
5. 运行效果
首页
数据可视化
系统管理
6. 扩展建议
实时数据看板
import time
from pages.仪表盘 import update_realtime_data
if st.button("启动实时模式"):
placeholder = st.empty()
while True:
df = update_realtime_data()
with placeholder.container():
st.line_chart(df)
time.sleep(1)
性能监控
集成Prometheus+Grafana:
from prometheus_client import start_http_server, Counter
REQUEST_COUNTER = Counter('app_requests', 'Total app requests')
start_http_server(9090)
自动化测试
使用pytest编写测试用例:
def test_login():
assert authenticate("admin", "password") == True
assert authenticate("hacker", "123456") == False
微服务集成
通过FastAPI提供REST接口:
from fastapi import FastAPI
app = FastAPI()
@app.get("/api/data")
def get_data():
return query_db("SELECT * FROM sales")
总结
本日志记录了:
Streamlit 高级性能优化技巧
自定义组件开发与主题定制
生产环境部署配置方法
企业级应用的架构设计
完整代码仓库:(https://gitcode.com/ICodeWR/StudyFlow/tree/main/src/streamLib/src/day10)
交流讨论:欢迎在评论区留言!
重要提示:本文主要是记录自己的学习与实践过程,所提内容或者观点仅代表个人意见,只是我以为的,不代表完全正确,不喜请勿关注。
领取专属 10元无门槛券
私享最新 技术干货