367 lines
15 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package edu.whut.trigger.http;
import edu.whut.api.IRAGService;
import edu.whut.api.response.Response;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.eclipse.jgit.api.Git;
import org.eclipse.jgit.transport.UsernamePasswordCredentialsProvider;
import org.redisson.api.RList;
import org.redisson.api.RedissonClient;
import org.springframework.ai.document.Document;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
import org.springframework.core.io.PathResource;
import org.springframework.http.MediaType;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.io.File;
import java.io.IOException;
import java.nio.file.*;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* RAG 服务控制器,实现 IRAGService 接口,提供知识库管理和检索相关的 HTTP 接口
*/
@Slf4j
@RestController
@RequestMapping("/api/v1/rag/")
@CrossOrigin("*")
@RequiredArgsConstructor
public class RAGController implements IRAGService {
// Ollama 对话模型(此处暂无直接使用,但保留注入以便扩展)
private final OllamaChatModel ollamaChatModel;
// 文本拆分器,将长文档切分为合适大小的段落或 Token 块
private final TokenTextSplitter tokenTextSplitter;
// 简易内存向量存储,用于快速测试或小规模存储
private final SimpleVectorStore simpleVectorStore;
// PostgreSQL pgvector 存储,用于持久化和检索嵌入向量
private final PgVectorStore pgVectorStore;
// Redisson 客户端,用于操作 Redis 列表存储 RAG 标签
private final RedissonClient redissonClient;
private final JdbcTemplate jdbcTemplate; // 注入Spring JDBC
/**
* 查询所有已上传的 RAG 标签列表
* GET /api/v1/rag/query_rag_tag_list
*/
@GetMapping("query_rag_tag_list")
@Override
public Response<List<String>> queryRagTagList() {
// 从 Redis 列表获取所有标签
RList<String> elements = redissonClient.getList("ragTag");
// 读一个快照,便于安全日志与返回
List<String> tags;
try {
tags = elements.readAll(); // Redisson 提供的批量读取
} catch (Exception e) {
// 兜底:某些客户端/版本没有 readAll 时使用迭代
log.warn("读取 Redis ragTag 列表使用 readAll 失败,改用迭代读取。", e);
tags = new ArrayList<>();
for (String s : elements) {
tags.add(s);
}
}
// 打印查询到的标签(数量 + 内容)
log.info("查询 RAG 标签列表,数量:{},内容:{}", tags.size(), tags);
return Response.<List<String>>builder()
.code("0000")
.info("调用成功")
.data(elements)
.build();
}
/**
* 查询指定知识库下的所有文件名
* GET /api/v1/rag/knowledge/{ragTag}/files
*/
@GetMapping("knowledge/files/{ragTag}")
public Response<List<String>> getKnowledgeFiles(@PathVariable String ragTag) {
log.info("查询知识库文件列表开始:{}", ragTag);
try {
// 使用原生SQL查询假设表名为 vector_store
String sql = """
SELECT DISTINCT metadata->>'path' as file_path
FROM vector_store
WHERE metadata->>'knowledge' = ?
ORDER BY metadata->>'path'
""";
// 执行查询并获取结果
List<String> filePaths = jdbcTemplate.queryForList(
sql,
new Object[]{ragTag},
String.class
);
log.info("查询知识库文件列表完成:{},共 {} 个文件", ragTag, filePaths.size());
return Response.<List<String>>builder()
.code("0000")
.info("查询成功")
.data(filePaths)
.build();
} catch (Exception e) {
log.error("查询知识库文件列表失败:{}", ragTag, e);
return Response.<List<String>>builder()
.code("9999")
.info("查询失败:" + e.getMessage())
.data(Collections.emptyList())
.build();
}
}
/**
* 注意:可以追加文件到相同的知识库中!
* 上传文件到知识库:
* - 使用 Tika 读取文档内容
* - 进行文本切分并贴上 ragTag 元数据
* - 存储到 pgVectorStore 并更新 Redis 标签列表
* POST /api/v1/rag/file/upload
*/
/**
* 注意:可以追加文件到相同的知识库中!
* 上传文件到知识库:
* - 使用 Tika 读取文档内容
* - 进行文本切分并贴上 ragTag 元数据
* - 存储到 pgVectorStore 并更新 Redis 标签列表
* POST /api/v1/rag/file/upload
*/
@PostMapping(path = "file/upload", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
@Override
public Response<String> uploadFile(
@RequestParam("ragTag") String ragTag,
@RequestParam("file") List<MultipartFile> files,
@RequestParam(value = "filePath", required = false) List<String> filePaths) {
log.info("上传知识库开始:{}", ragTag);
log.info("待处理文件数量:{},传入路径数量:{}", files.size(), filePaths != null ? filePaths.size() : 0);
// 使用带索引的 for 循环,保证与 filePath 一一对应
for (int i = 0; i < files.size(); i++) {
MultipartFile file = files.get(i);
String originalFilename = file.getOriginalFilename();
// ===== 路径标准化处理 =====
// 1. 获取原始路径(优先使用前端传入的路径,否则使用文件名)
String rawPath = (filePaths != null && i < filePaths.size() && filePaths.get(i) != null && !filePaths.get(i).isBlank())
? filePaths.get(i)
: (originalFilename != null ? originalFilename : "");
// 2. 标准化路径格式(统一用 / 分隔符,去除开头多余的 ./ 或 /
String normalizedPath = rawPath
.replace("\\", "/") // 统一使用正斜杠
.replaceAll("^[/.]+", "") // 去除开头的 ./ 或 /
.replaceAll("/+", "/"); // 替换多个连续的 / 为单个
// 打印当前处理的文件信息(包含原始和标准化路径)
log.info("正在处理第 {} 个文件 - 原始文件名: {}, 标准化路径: {}",
i + 1,
originalFilename,
normalizedPath);
try {
// 读取上传文件,提取文档内容
log.debug("开始解析文件内容: {}", normalizedPath);
TikaDocumentReader documentReader = new TikaDocumentReader(file.getResource());
List<Document> documents = documentReader.get();
// 对文档进行 Token 拆分
log.debug("开始拆分文件内容: {}", normalizedPath);
List<Document> documentSplitterList = tokenTextSplitter.apply(documents);
// ===== 元数据设置(保留完整标准化路径)=====
// 1. 原始文档设置元数据
documents.forEach(doc -> {
doc.getMetadata().put("knowledge", ragTag);
doc.getMetadata().put("path", normalizedPath); // 使用标准化路径
doc.getMetadata().put("original_filename", originalFilename); // 可选:额外保留原始文件名
});
// 2. 拆分后的文档设置元数据
documentSplitterList.forEach(doc -> {
doc.getMetadata().put("knowledge", ragTag);
doc.getMetadata().put("path", normalizedPath);
doc.getMetadata().put("original_filename", originalFilename);
});
// 存储拆分后的文档到 pgVectorStore
log.debug("开始存储文件到向量数据库: {}", normalizedPath);
pgVectorStore.accept(documentSplitterList);
log.info("文件处理完成: {}", normalizedPath);
} catch (Exception e) {
log.error("文件处理失败:{} - {}", normalizedPath, e.getMessage(), e);
// 可选:记录失败文件信息,但不中断整体流程
}
}
// 更新 Redis 标签列表(避免重复)
RList<String> elements = redissonClient.getList("ragTag");
if (!elements.contains(ragTag)) {
elements.add(ragTag);
log.info("新增知识库标签: {}", ragTag);
} else {
log.info("知识库标签已存在,无需新增: {}", ragTag);
}
log.info("上传知识库完成:{},共处理 {} 个文件", ragTag, files.size());
return Response.<String>builder().code("0000").info("调用成功").build();
}
/**
* 克隆并分析 Git 仓库:
* - 克隆指定仓库到本地
* - 遍历文件,使用 Tika 提取并拆分
* - 存储到 pgVectorStore 并更新 Redis 标签列表
* POST /api/v1/rag/analyze_git_repository
*/
@PostMapping("analyze_git_repository")
@Override
public Response<String> analyzeGitRepository(
@RequestParam("repoUrl") String repoUrl,
@RequestParam("userName") String userName,
@RequestParam("token") String token,
@RequestParam("ragTag") String ragTag) throws Exception {
String localPath = "./git-cloned-repo";
String repoProjectName = StringUtils.isNotBlank(ragTag) ?
ragTag :
extractProjectName(repoUrl);
log.info("克隆路径:{}", new File(localPath).getAbsolutePath());
// 1. 干净克隆
FileUtils.deleteDirectory(new File(localPath));
Git git = Git.cloneRepository()
.setURI(repoUrl)
.setDirectory(new File(localPath))
.setCredentialsProvider(new UsernamePasswordCredentialsProvider(userName, token))
.call();
// 2. 遍历并处理文件
Files.walkFileTree(Paths.get(localPath), new SimpleFileVisitor<>() {
@Override
public FileVisitResult preVisitDirectory(Path dir, BasicFileAttributes attrs) {
// 跳过 .git 目录
if (".git".equals(dir.getFileName().toString())) {
return FileVisitResult.SKIP_SUBTREE;
}
return FileVisitResult.CONTINUE;
}
@Override
public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) {
log.info("解析并上传文件:{} -> {}", repoProjectName, file.getFileName());
try {
// 2.1 读取原始文档(不可变列表)
List<Document> raw = new TikaDocumentReader(new PathResource(file)).get();
// 2.2 复制为可变列表并过滤掉空内容
List<Document> docs = new ArrayList<>(raw);
docs.removeIf(d -> d.getText() == null || d.getText().trim().isEmpty());
if (docs.isEmpty()) {
return FileVisitResult.CONTINUE;
}
// 2.3 打标签
docs.forEach(d -> d.getMetadata().put("knowledge", repoProjectName));
// 2.4 拆分并打标签
List<Document> splits = tokenTextSplitter.apply(docs);
splits.forEach(d -> d.getMetadata().put("knowledge", repoProjectName));
// 2.5 写入向量库
pgVectorStore.accept(splits);
} catch (Exception e) {
// 无法读取、拆分或存储时记录错误并跳过
log.error("文件解析上传失败:{}", file.getFileName(), e);
}
return FileVisitResult.CONTINUE;
}
@Override
public FileVisitResult visitFileFailed(Path file, IOException exc) {
// 文件访问失败时也不影响整体执行
log.warn("访问文件失败:{} - {}", file, exc.getMessage());
return FileVisitResult.CONTINUE;
}
});
// 3. 清理本地
FileUtils.deleteDirectory(new File(localPath));
git.close();
// 4. 更新 Redis 标签列表
RList<String> elements = redissonClient.getList("ragTag");
if (!elements.contains(repoProjectName)) {
elements.add(repoProjectName);
}
log.info("仓库分析并上传完成:{}", repoUrl);
return Response.<String>builder().code("0000").info("调用成功").build();
}
/**
* 测试接口
* @return
*/
@GetMapping("knowledge/_ping")
public Response<String> ragPing() {
log.info("RAG knowledge ping");
return Response.<String>builder().code("0000").info("ok").build();
}
/**
* 删除指定知识库(按 ragTag
* - 从 pgvector 向量库中删除所有 metadata.knowledge == ragTag 的文档
* - 从 Redis ragTag 列表中删除该标签
* DELETE /api/v1/rag/knowledge/{ragTag}
*/
@DeleteMapping("knowledge/{ragTag}")
public Response<String> deleteKnowledge(@PathVariable("ragTag") String ragTag) {
log.info("删除知识库开始:{}", ragTag);
try {
// 1) 删除向量库中对应知识库的所有向量1.0.0-M6支持过滤表达式删除
pgVectorStore.delete("knowledge == '" + ragTag + "'");
// 2) 从 Redis 标签列表移除
RList<String> elements = redissonClient.getList("ragTag");
elements.remove(ragTag);
log.info("删除知识库完成:{}", ragTag);
return Response.<String>builder().code("0000").info("删除成功").build();
} catch (Exception e) {
log.error("删除知识库失败:{}", ragTag, e);
return Response.<String>builder().code("9999").info("删除失败:" + e.getMessage()).build();
}
}
/**
* 从 Git 仓库 URL 提取项目名称(去除 .git 后缀)
*/
private String extractProjectName(String repoUrl) {
String[] parts = repoUrl.split("/");
String projectNameWithGit = parts[parts.length - 1];
return projectNameWithGit.replace(".git", "");
}
}