diff --git a/middleware/auth.js b/middleware/auth.js index 742eeac..9ba6878 100644 --- a/middleware/auth.js +++ b/middleware/auth.js @@ -1,120 +1,221 @@ import { siteKey } from "../config.js"; import AppError from "../utils/errors.js"; -import { PrismaClient } from '@prisma/client'; +import { PrismaClient } from "@prisma/client"; const prisma = new PrismaClient(); export const ACCESS_TYPES = { - NO_PASSWORD_WRITABLE: 'NO_PASSWORD_WRITABLE', - NO_PASSWORD_READABLE: 'NO_PASSWORD_READABLE', - NO_PASSWORD_UNREADABLE: 'NO_PASSWORD_UNREADABLE' + PUBLIC: "PUBLIC", + PROTECTED: "PROTECTED", + PRIVATE: "PRIVATE", }; export const checkSiteKey = (req, res, next) => { - const providedKey = req.headers['x-site-key']; + if (!siteKey) { + return next(); + } - if (!siteKey) { - return next(); - } + const providedKey = + req.headers["x-site-key"] || req.query.sitekey || req.body?.sitekey; + if (!providedKey || providedKey !== siteKey) { + return res.status(401).json({ + statusCode: 401, + message: "此服务器已开启站点密钥验证,请提供有效的站点密钥", + }); + } - if (!providedKey || providedKey !== siteKey) { - const error = AppError.createError( - AppError.HTTP_STATUS.UNAUTHORIZED, - "此服务器已开启站点密钥验证,请在请求头中添加 x-site-key 以继续访问" - ); - return res.status(error.statusCode).json(error); - } - - next(); + next(); }; async function getOrCreateDevice(uuid, className) { + try { let device = await prisma.device.findUnique({ - where: { uuid } + where: { uuid }, }); if (!device) { + try { device = await prisma.device.create({ - data: { - uuid, - name: className || null, - accessType: ACCESS_TYPES.NO_PASSWORD_WRITABLE - } + data: { + uuid, + name: className || null, + accessType: ACCESS_TYPES.PUBLIC, + }, }); + } catch (error) { + // 如果是唯一约束错误(并发创建),重新获取设备 + if (error.code === "P2002") { + device = await prisma.device.findUnique({ + where: { uuid }, + }); + } else { + throw error; + } + } + } + + // 如果设备没有密码,自动设为public + if ( + device && + !device.password && + device.accessType !== ACCESS_TYPES.PUBLIC + ) { + device = await prisma.device.update({ + where: { uuid }, + data: { accessType: ACCESS_TYPES.PUBLIC }, + }); } return device; -} - -async function validatePassword(device, password) { - if (!device.password) return true; - return device.password === password; + } catch (error) { + console.error("Error in getOrCreateDevice:", error); + throw error; + } } export const authMiddleware = async (req, res, next) => { - const { namespace } = req.params; - const { password } = req.body; + const { namespace } = req.params; + const password = + req.headers["x-namespace-password"] || + req.query.password || + req.body?.password; - try { - const device = await getOrCreateDevice(namespace, req.body.className); - req.device = device; + try { + const device = await getOrCreateDevice(namespace, req.body?.className); + req.device = device; - if (device.password && !await validatePassword(device, password)) { - return res.status(401).json({ error: 'Invalid password' }); - } - - next(); - } catch (error) { - console.error('Auth middleware error:', error); - res.status(500).json({ error: 'Internal server error' }); + if (device.password && password !== device.password) { + return res.status(401).json({ + statusCode: 401, + message: "设备密码验证失败", + }); } + + next(); + } catch (error) { + console.error("Auth middleware error:", error); + res.status(500).json({ + statusCode: 500, + message: "服务器内部错误", + }); + } }; export const readAuthMiddleware = async (req, res, next) => { - const { namespace } = req.params; + const { namespace } = req.params; + const password = + req.headers["x-namespace-password"] || + req.query.password || + req.body?.password; - try { - const device = await getOrCreateDevice(namespace); - req.device = device; + try { + const device = await getOrCreateDevice(namespace); + res.locals.device = device; - if (device.accessType === ACCESS_TYPES.NO_PASSWORD_UNREADABLE) { - return res.status(403).json({ error: 'Device is not readable' }); - } - - if (device.accessType === ACCESS_TYPES.NO_PASSWORD_READABLE) { - return next(); - } - - if (device.password) { - const { password } = req.body; - if (!await validatePassword(device, password)) { - return res.status(401).json({ error: 'Invalid password' }); - } - } - - next(); - } catch (error) { - console.error('Read auth middleware error:', error); - res.status(500).json({ error: 'Internal server error' }); + // PUBLIC and PROTECTED devices are always readable + if ([ACCESS_TYPES.PUBLIC, ACCESS_TYPES.PROTECTED].includes(device.accessType)) { + return next(); } + + // For PRIVATE devices, require password + if (!device.password || password !== device.password) { + return res.status(401).json({ + statusCode: 401, + message: "设备密码验证失败", + }); + } + + next(); + } catch (error) { + console.error("Read auth middleware error:", error); + res.status(500).json({ + statusCode: 500, + message: "服务器内部错误", + }); + } }; export const writeAuthMiddleware = async (req, res, next) => { - const { namespace } = req.params; + const { namespace } = req.params; + const password = + req.headers["x-namespace-password"] || + req.query.password || + req.body?.password; - try { - const device = await getOrCreateDevice(namespace); - req.device = device; + try { + const device = await getOrCreateDevice(namespace); + res.locals.device = device; - if (device.password) { - const { password } = req.body; - if (!await validatePassword(device, password)) { - return res.status(401).json({ error: 'Invalid password' }); - } - } - - next(); - } catch (error) { - console.error('Write auth middleware error:', error); - res.status(500).json({ error: 'Internal server error' }); + // PUBLIC devices are always writable + if (device.accessType === ACCESS_TYPES.PUBLIC) { + return next(); } -}; \ No newline at end of file + + // For PROTECTED and PRIVATE devices, require password + if (!device.password || password !== device.password) { + return res.status(401).json({ + statusCode: 401, + message: "设备密码验证失败", + }); + } + + next(); + } catch (error) { + console.error("Write auth middleware error:", error); + res.status(500).json({ + statusCode: 500, + message: "服务器内部错误", + }); + } +}; + +export const removePasswordMiddleware = async (req, res, next) => { + const { namespace } = req.params; + const password = + req.headers["x-namespace-password"] || + req.query.password || + req.body?.password; + const providedKey = + req.headers["x-site-key"] || req.query.sitekey || req.body?.sitekey; + + try { + const device = await getOrCreateDevice(namespace); + req.device = device; + + // 验证站点令牌(如果设置了的话) + if (siteKey && (!providedKey || providedKey !== siteKey)) { + return res.status(401).json({ + statusCode: 401, + message: "此服务器已开启站点密钥验证,请提供有效的站点密钥", + }); + } + + // 验证设备密码 + if (device.password) { + if (!password || password !== device.password) { + return res.status(401).json({ + statusCode: 401, + message: "设备密码验证失败", + }); + } + } else { + return res.status(400).json({ + statusCode: 400, + message: "设备当前没有设置密码", + }); + } + + // 更新设备,移除密码 + await prisma.device.update({ + where: { uuid: namespace }, + data: { password: null }, + }); + + res.json({ message: "密码已成功移除" }); + } catch (error) { + console.error("Remove password middleware error:", error); + res.status(500).json({ + statusCode: 500, + message: "服务器内部错误", + }); + } +}; diff --git a/middleware/rateLimiter.js b/middleware/rateLimiter.js index da6d466..45db6b4 100644 --- a/middleware/rateLimiter.js +++ b/middleware/rateLimiter.js @@ -26,7 +26,7 @@ export const globalLimiter = rateLimit({ // API限速器 export const apiLimiter = rateLimit({ windowMs: 1 * 60 * 1000, // 1分钟 - limit: 20, // 每个IP在windowMs时间内最多允许20个请求 + limit: 50, // 每个IP在windowMs时间内最多允许50个请求 standardHeaders: "draft-7", legacyHeaders: false, message: "API请求过于频繁,请稍后再试", @@ -38,7 +38,7 @@ export const apiLimiter = rateLimit({ // 写操作限速器(更严格) export const writeLimiter = rateLimit({ windowMs: 1 * 60 * 1000, // 1分钟 - limit: 10, // 每个IP在windowMs时间内最多允许10个写操作 + limit: 20, // 每个IP在windowMs时间内最多允许20个写操作 standardHeaders: "draft-7", legacyHeaders: false, message: "写操作请求过于频繁,请稍后再试", @@ -50,7 +50,7 @@ export const writeLimiter = rateLimit({ // 删除操作限速器(最严格) export const deleteLimiter = rateLimit({ windowMs: 1 * 60 * 1000, // 5分钟 - limit: 1, // 每个IP在windowMs时间内最多允许5个删除操作 + limit: 10, // 每个IP在windowMs时间内最多允许10个删除操作 standardHeaders: "draft-7", legacyHeaders: false, message: "删除操作请求过于频繁,请稍后再试", @@ -74,7 +74,7 @@ export const authLimiter = rateLimit({ // 批量操作限速器(比写操作更严格) export const batchLimiter = rateLimit({ windowMs: 1 * 60 * 1000, // 5分钟 - limit: 10, // 每个IP在windowMs时间内最多允许5个批量操作 + limit: 10, // 每个IP在windowMs时间内最多允许10个批量操作 standardHeaders: "draft-7", legacyHeaders: false, message: "批量操作请求过于频繁,请稍后再试", diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 06911c0..1dd28b2 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -7,6 +7,12 @@ datasource db { url = env("DATABASE_URL") } +enum AccessType { + PUBLIC // No password required for read/write + PROTECTED // No password for read, password for write + PRIVATE // Password required for read/write +} + model KVStore { namespace String @db.Char(36) key String @@ -22,7 +28,7 @@ model Device { uuid String @id @db.Char(36) password String? name String? - accessType String @default("NO_PASSWORD_WRITABLE") + accessType AccessType @default(PUBLIC) createdAt DateTime @default(now()) updatedAt DateTime @updatedAt } diff --git a/routes/kv.js b/routes/kv.js index 62a2bbc..0bb3a5e 100644 --- a/routes/kv.js +++ b/routes/kv.js @@ -5,10 +5,17 @@ import { checkSiteKey } from "../middleware/auth.js"; import { v4 as uuidv4 } from "uuid"; import errors from "../utils/errors.js"; import { PrismaClient } from "@prisma/client"; -import { readAuthMiddleware, writeAuthMiddleware } from "../middleware/auth.js"; +import { + readAuthMiddleware, + writeAuthMiddleware, + removePasswordMiddleware, +} from "../middleware/auth.js"; const prisma = new PrismaClient(); +// 定义有效的访问类型 +const VALID_ACCESS_TYPES = ["PUBLIC", "PROTECTED", "PRIVATE"]; + // 检查是否为受限UUID的中间件 const checkRestrictedUUID = (req, res, next) => { const restrictedUUID = "00000000-0000-4000-8000-000000000000"; @@ -25,20 +32,45 @@ router.use(checkSiteKey); // Get device info router.get( "/:namespace/_info", + checkRestrictedUUID, readAuthMiddleware, errors.catchAsync(async (req, res) => { - try { - const { device } = req; - res.json({ - uuid: device.uuid, - name: device.name, - accessType: device.accessType, - hasPassword: !!device.password, + const device = res.locals.device; + if (!device) { + return res.status(404).json({ + statusCode: 404, + message: "设备不存在", }); - } catch (error) { - console.error("Error getting device info:", error); - res.status(500).json({ error: "Internal server error" }); } + res.json({ + uuid: device.uuid, + name: device.name, + accessType: device.accessType, + hasPassword: !!device.password, + }); + }) +); + +// Get device info +router.get( + "/:namespace/_check", + checkRestrictedUUID, + writeAuthMiddleware, + errors.catchAsync(async (req, res) => { + const device = res.locals.device; + if (!device) { + return res.status(404).json({ + statusCode: 404, + message: "设备不存在", + }); + } + res.json({ + status: 'success', + uuid: device.uuid, + name: device.name, + accessType: device.accessType, + hasPassword: !!device.password, + }); }) ); @@ -46,13 +78,13 @@ router.get( router.post( "/:namespace/_password", writeAuthMiddleware, - errors.catchAsync(async (req, res) => { + errors.catchAsync(async (req, res, next) => { const { newPassword, oldPassword } = req.body; - const { device } = req; + const device = res.locals.device; try { if (device.password && oldPassword !== device.password) { - return res.status(401).json({ error: "Invalid old password" }); + return next(errors.createError(500, "密码错误")); } await prisma.device.update({ @@ -60,10 +92,9 @@ router.post( data: { password: newPassword }, }); - res.json({ message: "Password updated successfully" }); + res.json({ message: "密码已成功修改" }); } catch (error) { - console.error("Error updating password:", error); - res.status(500).json({ error: "Internal server error" }); + return next(errors.createError(500, "无法修改密码")); } }) ); @@ -74,27 +105,38 @@ router.put( writeAuthMiddleware, errors.catchAsync(async (req, res) => { const { name, accessType } = req.body; - const { device } = req; + const device = res.locals.device; - try { - const updatedDevice = await prisma.device.update({ - where: { uuid: device.uuid }, - data: { - name: name || device.name, - accessType: accessType || device.accessType, - }, + // 验证 accessType + if (accessType && !VALID_ACCESS_TYPES.includes(accessType)) { + return res.status(400).json({ + error: `Invalid access type. Must be one of: ${VALID_ACCESS_TYPES.join(", ")}`, }); - - res.json({ - uuid: updatedDevice.uuid, - name: updatedDevice.name, - accessType: updatedDevice.accessType, - hasPassword: !!updatedDevice.password, - }); - } catch (error) { - console.error("Error updating device info:", error); - res.status(500).json({ error: "Internal server error" }); } + + const updatedDevice = await prisma.device.update({ + where: { uuid: device.uuid }, + data: { + name: name || device.name, + accessType: accessType || device.accessType, + }, + }); + + res.json({ + uuid: updatedDevice.uuid, + name: updatedDevice.name, + accessType: updatedDevice.accessType, + hasPassword: !!updatedDevice.password, + }); + }) +); + +// Remove device password +router.delete( + "/:namespace/_password", + removePasswordMiddleware, + errors.catchAsync(async (req, res) => { + res.json({ message: "密码已成功移除" }); }) );