1
1
mirror of https://github.com/ZeroCatDev/ClassworksKV.git synced 2025-12-07 13:03:09 +00:00
ClassworksKV/utils/tokenManager.js
SunWuyuan 2ab90ffebc
feat: Implement Refresh Token system with enhanced security and user experience
- Added refresh token support in the account model with new fields: refreshToken, refreshTokenExpiry, and tokenVersion.
- Created a new token management utility (utils/tokenManager.js) for generating and verifying access and refresh tokens.
- Updated JWT utility (utils/jwt.js) to maintain backward compatibility while introducing new token generation methods.
- Enhanced middleware for JWT authentication to support new token types and automatic token refreshing.
- Expanded API endpoints in routes/accounts.js to include refresh token functionality, logout options, and token info retrieval.
- Introduced automatic token refresh mechanism in the front-end integration examples.
- Comprehensive migration checklist and documentation for the new refresh token system.
- Added database migration script to accommodate new fields in the Account table.
2025-11-02 09:48:03 +08:00

293 lines
7.4 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import jwt from 'jsonwebtoken';
import crypto from 'crypto';
import { PrismaClient } from '@prisma/client';
// Module-level Prisma client; the functions below use it to read and update `account` records.
const prisma = new PrismaClient();
// --- Token configuration (read once from the environment at module load) ---

// Reads an environment variable that holds a PEM key and turns every literal
// "\n" (backslash followed by n) into a real newline, so a multi-line key can
// be stored on a single line. Yields undefined when the variable is not set.
const readPemFromEnv = (variableName) => process.env[variableName]?.replace(/\\n/g, '\n');

// Secrets used for both signing and verification when JWT_ALG is not RS256.
// NOTE(review): the fallbacks are placeholder strings - confirm that JWT_SECRET
// and REFRESH_TOKEN_SECRET are always set in real deployments.
const ACCESS_TOKEN_SECRET =
  process.env.JWT_SECRET || 'your-access-token-secret-change-this-in-production';
const REFRESH_TOKEN_SECRET =
  process.env.REFRESH_TOKEN_SECRET || 'your-refresh-token-secret-change-this-in-production';

// Token lifetimes
const ACCESS_TOKEN_EXPIRES_IN = process.env.ACCESS_TOKEN_EXPIRES_IN || '15m'; // 15 minutes
const REFRESH_TOKEN_EXPIRES_IN = process.env.REFRESH_TOKEN_EXPIRES_IN || '7d'; // 7 days

// JWT signing algorithm, normalised to upper case (defaults to HS256)
const JWT_ALG = (process.env.JWT_ALG || 'HS256').toUpperCase();

// RS256 key pairs (only needed when the RSA algorithm is selected)
const ACCESS_TOKEN_PRIVATE_KEY = readPemFromEnv('ACCESS_TOKEN_PRIVATE_KEY');
const ACCESS_TOKEN_PUBLIC_KEY = readPemFromEnv('ACCESS_TOKEN_PUBLIC_KEY');
const REFRESH_TOKEN_PRIVATE_KEY = readPemFromEnv('REFRESH_TOKEN_PRIVATE_KEY');
const REFRESH_TOKEN_PUBLIC_KEY = readPemFromEnv('REFRESH_TOKEN_PUBLIC_KEY');
/**
 * Resolve the key material used to sign and verify a token.
 *
 * When JWT_ALG is 'RS256' the private key signs and the public key verifies,
 * and both must be configured. For any other JWT_ALG value the shared secret
 * of the requested token type is returned for both roles.
 *
 * @param {string} [tokenType='access'] - 'access' selects the access-token
 *   keys; any other value selects the refresh-token keys.
 * @returns {{signKey: string, verifyKey: string}} Keys for jwt.sign / jwt.verify.
 * @throws {Error} If RS256 is selected and either key of the pair is missing.
 */
function getKeys(tokenType = 'access') {
  const wantsAccess = tokenType === 'access';

  if (JWT_ALG !== 'RS256') {
    // Default (HS256): one shared secret both signs and verifies.
    const sharedSecret = wantsAccess ? ACCESS_TOKEN_SECRET : REFRESH_TOKEN_SECRET;
    return { signKey: sharedSecret, verifyKey: sharedSecret };
  }

  const signKey = wantsAccess ? ACCESS_TOKEN_PRIVATE_KEY : REFRESH_TOKEN_PRIVATE_KEY;
  const verifyKey = wantsAccess ? ACCESS_TOKEN_PUBLIC_KEY : REFRESH_TOKEN_PUBLIC_KEY;

  if (!(signKey && verifyKey)) {
    const prefix = tokenType.toUpperCase();
    throw new Error(`RS256 需要同时提供 ${prefix}_TOKEN_PRIVATE_KEY 与 ${prefix}_TOKEN_PUBLIC_KEY`);
  }

  return { signKey, verifyKey };
}
/**
 * Generate a signed access token (JWT) for an account.
 *
 * The payload carries the account's id, provider, email, name, avatarUrl and
 * tokenVersion, tagged with `type: 'access'`. The token is signed with the
 * access signing key using JWT_ALG, expires after ACCESS_TOKEN_EXPIRES_IN, and
 * is issued by 'ClassworksKV' for the 'classworks-client' audience.
 *
 * @param {object} account - Account record; reads `id`, `provider`, `email`,
 *   `name`, `avatarUrl` and `tokenVersion`.
 * @returns {string} The value returned by `jwt.sign`.
 * @throws {Error} If the signing key cannot be resolved or signing fails.
 */
export function generateAccessToken(account) {
  const { signKey } = getKeys('access');

  const payload = {
    type: 'access',
    accountId: account.id,
    provider: account.provider,
    email: account.email,
    name: account.name,
    avatarUrl: account.avatarUrl,
    // Use `??`, not `||`: a stored version of 0 must stay 0. With `||` it was
    // rewritten to 1, so the token never matched the account's stored version
    // and would only start matching after the version had been incremented.
    tokenVersion: account.tokenVersion ?? 1,
  };

  return jwt.sign(payload, signKey, {
    expiresIn: ACCESS_TOKEN_EXPIRES_IN,
    algorithm: JWT_ALG,
    issuer: 'ClassworksKV',
    audience: 'classworks-client',
  });
}
/**
 * Generate a signed refresh token (JWT) for an account.
 *
 * The payload carries only the account id, the tokenVersion and a random
 * `jti`, tagged with `type: 'refresh'`. The token is signed with the refresh
 * signing key using JWT_ALG, expires after REFRESH_TOKEN_EXPIRES_IN, and is
 * issued by 'ClassworksKV' for the 'classworks-client' audience.
 *
 * @param {object} account - Account record; reads `id` and `tokenVersion`.
 * @returns {string} The value returned by `jwt.sign`.
 * @throws {Error} If the signing key cannot be resolved or signing fails.
 */
export function generateRefreshToken(account) {
  const { signKey } = getKeys('refresh');

  const payload = {
    type: 'refresh',
    accountId: account.id,
    // Use `??`, not `||`: a stored version of 0 must stay 0. With `||` it was
    // rewritten to 1, so the token never matched the account's stored version
    // and would only start matching after the version had been incremented.
    tokenVersion: account.tokenVersion ?? 1,
    // Random identifier (16 random bytes, hex-encoded) added for extra security.
    jti: crypto.randomBytes(16).toString('hex'),
  };

  return jwt.sign(payload, signKey, {
    expiresIn: REFRESH_TOKEN_EXPIRES_IN,
    algorithm: JWT_ALG,
    issuer: 'ClassworksKV',
    audience: 'classworks-client',
  });
}
/**
 * Verify an access token and return its decoded payload.
 *
 * The token is checked with the access verification key against JWT_ALG, the
 * 'ClassworksKV' issuer and the 'classworks-client' audience; the payload must
 * additionally be tagged `type: 'access'`.
 *
 * @param {string} token - The JWT to verify.
 * @returns {object} The decoded payload.
 * @throws {Error} Any error raised by `jwt.verify` (propagated unchanged), or
 *   'Invalid token type' when the payload is not an access token.
 */
export function verifyAccessToken(token) {
  const { verifyKey } = getKeys('access');
  const verifyOptions = {
    algorithms: [JWT_ALG],
    issuer: 'ClassworksKV',
    audience: 'classworks-client',
  };

  const claims = jwt.verify(token, verifyKey, verifyOptions);
  if (claims.type === 'access') {
    return claims;
  }
  throw new Error('Invalid token type');
}
/**
 * Verify a refresh token and return its decoded payload.
 *
 * The token is checked with the refresh verification key against JWT_ALG, the
 * 'ClassworksKV' issuer and the 'classworks-client' audience; the payload must
 * additionally be tagged `type: 'refresh'`.
 *
 * @param {string} token - The JWT to verify.
 * @returns {object} The decoded payload.
 * @throws {Error} Any error raised by `jwt.verify` (propagated unchanged), or
 *   'Invalid token type' when the payload is not a refresh token.
 */
export function verifyRefreshToken(token) {
  const { verifyKey } = getKeys('refresh');
  const verifyOptions = {
    algorithms: [JWT_ALG],
    issuer: 'ClassworksKV',
    audience: 'classworks-client',
  };

  const claims = jwt.verify(token, verifyKey, verifyOptions);
  if (claims.type === 'refresh') {
    return claims;
  }
  throw new Error('Invalid token type');
}
/**
 * Generate a token pair (access token + refresh token) for an account and
 * store the refresh token on the account record.
 *
 * @param {object} account - Account record; `id` selects the row to update and
 *   the whole object is passed to the two token generators.
 * @returns {Promise<{accessToken: string, refreshToken: string,
 *   accessTokenExpiresIn: (string|number), refreshTokenExpiresIn: (string|number)}>}
 *   The new tokens together with their configured lifetimes.
 * @throws {Error} If token generation, lifetime parsing or the database update fails.
 */
export async function generateTokenPair(account) {
  const accessToken = generateAccessToken(account);
  const refreshToken = generateRefreshToken(account);

  // Absolute expiry of the refresh token: now + configured lifetime.
  const lifetimeMs = parseExpirationToMs(REFRESH_TOKEN_EXPIRES_IN);
  const refreshTokenExpiry = new Date(Date.now() + lifetimeMs);

  // Store the refresh token and its expiry on the account row.
  const data = {
    refreshToken,
    refreshTokenExpiry,
    updatedAt: new Date(),
  };
  await prisma.account.update({ where: { id: account.id }, data });

  return {
    accessToken,
    refreshToken,
    accessTokenExpiresIn: ACCESS_TOKEN_EXPIRES_IN,
    refreshTokenExpiresIn: REFRESH_TOKEN_EXPIRES_IN,
  };
}
/**
 * Exchange a refresh token for a new access token.
 *
 * The refresh token must verify, belong to an existing account, equal the
 * token stored on that account, not be past the stored expiry (when one is
 * stored), and carry the account's current tokenVersion.
 *
 * @param {string} refreshToken - The refresh token presented by the client.
 * @returns {Promise<{accessToken: string, accessTokenExpiresIn: (string|number),
 *   account: {id: *, provider: *, email: *, name: *, avatarUrl: *}}>}
 *   The new access token, its configured lifetime and public account fields.
 * @throws {Error} Verification errors are propagated unchanged; otherwise one
 *   of 'Account not found', 'Invalid refresh token', 'Refresh token expired'
 *   or 'Token version mismatch'.
 */
export async function refreshAccessToken(refreshToken) {
  // Signature / claims check first; failures propagate to the caller.
  const claims = verifyRefreshToken(refreshToken);

  // Load the account the token claims to belong to.
  const account = await prisma.account.findUnique({
    where: { id: claims.accountId },
  });
  if (!account) {
    throw new Error('Account not found');
  }

  // The presented token must be the one currently stored for the account.
  const isStoredToken = account.refreshToken === refreshToken;
  if (!isStoredToken) {
    throw new Error('Invalid refresh token');
  }

  // Enforce the stored expiry, if there is one.
  const storedExpiry = account.refreshTokenExpiry;
  if (storedExpiry && storedExpiry < new Date()) {
    throw new Error('Refresh token expired');
  }

  // Tokens issued under a different token version are rejected.
  const isCurrentVersion = account.tokenVersion === claims.tokenVersion;
  if (!isCurrentVersion) {
    throw new Error('Token version mismatch');
  }

  const accessToken = generateAccessToken(account);
  const publicAccount = {
    id: account.id,
    provider: account.provider,
    email: account.email,
    name: account.name,
    avatarUrl: account.avatarUrl,
  };
  return {
    accessToken,
    accessTokenExpiresIn: ACCESS_TOKEN_EXPIRES_IN,
    account: publicAccount,
  };
}
/**
 * Revoke every token of an account (log out of all devices).
 *
 * Increments the account's tokenVersion and clears the stored refresh token
 * and its expiry in a single update.
 *
 * @param {*} accountId - Id of the account row to update.
 * @returns {Promise<void>}
 * @throws {Error} If the database update fails.
 */
export async function revokeAllTokens(accountId) {
  const where = { id: accountId };
  const data = {
    tokenVersion: { increment: 1 },
    refreshToken: null,
    refreshTokenExpiry: null,
    updatedAt: new Date(),
  };

  await prisma.account.update({ where, data });
}
/**
 * Revoke the account's current refresh token (log out of the current device).
 *
 * Clears the stored refresh token and its expiry; tokenVersion is left as is.
 *
 * @param {*} accountId - Id of the account row to update.
 * @returns {Promise<void>}
 * @throws {Error} If the database update fails.
 */
export async function revokeRefreshToken(accountId) {
  const where = { id: accountId };
  const data = {
    refreshToken: null,
    refreshTokenExpiry: null,
    updatedAt: new Date(),
  };

  await prisma.account.update({ where, data });
}
/**
 * Map a time-unit name to its length in milliseconds.
 *
 * Accepts the short and long unit spellings that jsonwebtoken's `expiresIn`
 * strings may use; the empty string means "no unit", i.e. milliseconds.
 *
 * @param {string} unit - Lower-cased unit name.
 * @returns {number|undefined} Milliseconds per unit, or undefined if unknown.
 */
function expirationUnitToMs(unit) {
  switch (unit) {
    case '':
    case 'ms':
    case 'msec':
    case 'msecs':
    case 'millisecond':
    case 'milliseconds':
      return 1;
    case 's':
    case 'sec':
    case 'secs':
    case 'second':
    case 'seconds':
      return 1000;
    case 'm':
    case 'min':
    case 'mins':
    case 'minute':
    case 'minutes':
      return 60 * 1000;
    case 'h':
    case 'hr':
    case 'hrs':
    case 'hour':
    case 'hours':
      return 60 * 60 * 1000;
    case 'd':
    case 'day':
    case 'days':
      return 24 * 60 * 60 * 1000;
    case 'w':
    case 'week':
    case 'weeks':
      return 7 * 24 * 60 * 60 * 1000;
    case 'y':
    case 'yr':
    case 'yrs':
    case 'year':
    case 'years':
      return 365.25 * 24 * 60 * 60 * 1000;
    default:
      return undefined;
  }
}

/**
 * Convert an expiration setting into milliseconds.
 *
 * The same setting is handed to `jwt.sign` as `expiresIn`, so this parser
 * accepts the same shapes: a number is a count of seconds; a string is a
 * number optionally followed by a unit ("15m", "7d", "1w", "2 days", "1.5h"),
 * matched case-insensitively; a string without a unit ("120") is a count of
 * milliseconds.
 *
 * @param {number|string} expiresIn - Lifetime as seconds (number) or a time
 *   span string.
 * @returns {number} The lifetime in milliseconds.
 * @throws {Error} 'Invalid expiration format' when the value cannot be parsed.
 */
function parseExpirationToMs(expiresIn) {
  if (typeof expiresIn === 'number') {
    return expiresIn * 1000;
  }
  if (typeof expiresIn !== 'string') {
    throw new Error('Invalid expiration format');
  }

  const match = /^(\d*\.?\d+) *([a-z]*)$/i.exec(expiresIn);
  if (!match) {
    throw new Error('Invalid expiration format');
  }

  const amount = Number(match[1]);
  const msPerUnit = expirationUnitToMs(match[2].toLowerCase());
  if (msPerUnit === undefined) {
    throw new Error('Invalid expiration format');
  }
  return amount * msPerUnit;
}
/**
 * Load the account referenced by a decoded token and check its token version.
 *
 * @param {object} decoded - Decoded token payload; reads `accountId` and
 *   `tokenVersion`.
 * @returns {Promise<object>} The account record returned by the database.
 * @throws {Error} 'Account not found' when no account has that id, or
 *   'Token version mismatch' when the stored version differs from the token's.
 */
export async function validateAccountToken(decoded) {
  const lookup = { where: { id: decoded.accountId } };
  const account = await prisma.account.findUnique(lookup);

  if (!account) {
    throw new Error('Account not found');
  }

  // The token must have been issued under the account's current version.
  const versionsAgree = account.tokenVersion === decoded.tokenVersion;
  if (!versionsAgree) {
    throw new Error('Token version mismatch');
  }

  return account;
}
// Backward-compatible exports: legacy names kept as aliases of
// generateAccessToken / verifyAccessToken.
export const signToken = generateAccessToken;
export const verifyToken = verifyAccessToken;
export const generateAccountToken = generateAccessToken;