367 lines
15 KiB
Java
367 lines
15 KiB
Java
package edu.whut.trigger.http;
|
||
|
||
import edu.whut.api.IRAGService;
|
||
import edu.whut.api.response.Response;
|
||
import lombok.RequiredArgsConstructor;
|
||
import lombok.extern.slf4j.Slf4j;
|
||
import org.apache.commons.io.FileUtils;
|
||
import org.apache.commons.lang3.StringUtils;
|
||
import org.eclipse.jgit.api.Git;
|
||
import org.eclipse.jgit.transport.UsernamePasswordCredentialsProvider;
|
||
import org.redisson.api.RList;
|
||
import org.redisson.api.RedissonClient;
|
||
import org.springframework.ai.document.Document;
|
||
import org.springframework.ai.ollama.OllamaChatModel;
|
||
import org.springframework.ai.reader.tika.TikaDocumentReader;
|
||
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
||
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
||
import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
|
||
import org.springframework.core.io.PathResource;
|
||
import org.springframework.http.MediaType;
|
||
import org.springframework.jdbc.core.JdbcTemplate;
|
||
import org.springframework.web.bind.annotation.*;
|
||
import org.springframework.web.multipart.MultipartFile;
|
||
|
||
import java.io.File;
|
||
import java.io.IOException;
|
||
import java.nio.file.*;
|
||
import java.nio.file.attribute.BasicFileAttributes;
|
||
import java.util.ArrayList;
|
||
import java.util.Collections;
|
||
import java.util.List;
|
||
import java.util.Objects;
|
||
import java.util.stream.Collectors;
|
||
|
||
/**
|
||
* RAG 服务控制器,实现 IRAGService 接口,提供知识库管理和检索相关的 HTTP 接口
|
||
*/
|
||
@Slf4j
|
||
@RestController
|
||
@RequestMapping("/api/v1/rag/")
|
||
@CrossOrigin("*")
|
||
@RequiredArgsConstructor
|
||
public class RAGController implements IRAGService {
|
||
|
||
// Ollama 对话模型(此处暂无直接使用,但保留注入以便扩展)
|
||
private final OllamaChatModel ollamaChatModel;
|
||
|
||
// 文本拆分器,将长文档切分为合适大小的段落或 Token 块
|
||
private final TokenTextSplitter tokenTextSplitter;
|
||
|
||
// 简易内存向量存储,用于快速测试或小规模存储
|
||
private final SimpleVectorStore simpleVectorStore;
|
||
|
||
// PostgreSQL pgvector 存储,用于持久化和检索嵌入向量
|
||
private final PgVectorStore pgVectorStore;
|
||
|
||
// Redisson 客户端,用于操作 Redis 列表存储 RAG 标签
|
||
private final RedissonClient redissonClient;
|
||
|
||
private final JdbcTemplate jdbcTemplate; // 注入Spring JDBC
|
||
/**
|
||
* 查询所有已上传的 RAG 标签列表
|
||
* GET /api/v1/rag/query_rag_tag_list
|
||
*/
|
||
@GetMapping("query_rag_tag_list")
|
||
@Override
|
||
public Response<List<String>> queryRagTagList() {
|
||
// 从 Redis 列表获取所有标签
|
||
RList<String> elements = redissonClient.getList("ragTag");
|
||
// 读一个快照,便于安全日志与返回
|
||
List<String> tags;
|
||
try {
|
||
tags = elements.readAll(); // Redisson 提供的批量读取
|
||
} catch (Exception e) {
|
||
// 兜底:某些客户端/版本没有 readAll 时使用迭代
|
||
log.warn("读取 Redis ragTag 列表使用 readAll 失败,改用迭代读取。", e);
|
||
tags = new ArrayList<>();
|
||
for (String s : elements) {
|
||
tags.add(s);
|
||
}
|
||
}
|
||
// 打印查询到的标签(数量 + 内容)
|
||
log.info("查询 RAG 标签列表,数量:{},内容:{}", tags.size(), tags);
|
||
return Response.<List<String>>builder()
|
||
.code("0000")
|
||
.info("调用成功")
|
||
.data(elements)
|
||
.build();
|
||
}
|
||
|
||
/**
|
||
* 查询指定知识库下的所有文件名
|
||
* GET /api/v1/rag/knowledge/{ragTag}/files
|
||
*/
|
||
@GetMapping("knowledge/files/{ragTag}")
|
||
public Response<List<String>> getKnowledgeFiles(@PathVariable String ragTag) {
|
||
log.info("查询知识库文件列表开始:{}", ragTag);
|
||
|
||
try {
|
||
// 使用原生SQL查询(假设表名为 vector_store)
|
||
String sql = """
|
||
SELECT DISTINCT metadata->>'path' as file_path
|
||
FROM vector_store
|
||
WHERE metadata->>'knowledge' = ?
|
||
ORDER BY metadata->>'path'
|
||
""";
|
||
|
||
// 执行查询并获取结果
|
||
List<String> filePaths = jdbcTemplate.queryForList(
|
||
sql,
|
||
new Object[]{ragTag},
|
||
String.class
|
||
);
|
||
|
||
log.info("查询知识库文件列表完成:{},共 {} 个文件", ragTag, filePaths.size());
|
||
return Response.<List<String>>builder()
|
||
.code("0000")
|
||
.info("查询成功")
|
||
.data(filePaths)
|
||
.build();
|
||
|
||
} catch (Exception e) {
|
||
log.error("查询知识库文件列表失败:{}", ragTag, e);
|
||
return Response.<List<String>>builder()
|
||
.code("9999")
|
||
.info("查询失败:" + e.getMessage())
|
||
.data(Collections.emptyList())
|
||
.build();
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 注意:可以追加文件到相同的知识库中!
|
||
* 上传文件到知识库:
|
||
* - 使用 Tika 读取文档内容
|
||
* - 进行文本切分并贴上 ragTag 元数据
|
||
* - 存储到 pgVectorStore 并更新 Redis 标签列表
|
||
* POST /api/v1/rag/file/upload
|
||
*/
|
||
/**
|
||
* 注意:可以追加文件到相同的知识库中!
|
||
* 上传文件到知识库:
|
||
* - 使用 Tika 读取文档内容
|
||
* - 进行文本切分并贴上 ragTag 元数据
|
||
* - 存储到 pgVectorStore 并更新 Redis 标签列表
|
||
* POST /api/v1/rag/file/upload
|
||
*/
|
||
@PostMapping(path = "file/upload", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
|
||
@Override
|
||
public Response<String> uploadFile(
|
||
@RequestParam("ragTag") String ragTag,
|
||
@RequestParam("file") List<MultipartFile> files,
|
||
@RequestParam(value = "filePath", required = false) List<String> filePaths) {
|
||
|
||
log.info("上传知识库开始:{}", ragTag);
|
||
log.info("待处理文件数量:{},传入路径数量:{}", files.size(), filePaths != null ? filePaths.size() : 0);
|
||
|
||
// 使用带索引的 for 循环,保证与 filePath 一一对应
|
||
for (int i = 0; i < files.size(); i++) {
|
||
MultipartFile file = files.get(i);
|
||
String originalFilename = file.getOriginalFilename();
|
||
|
||
// ===== 路径标准化处理 =====
|
||
// 1. 获取原始路径(优先使用前端传入的路径,否则使用文件名)
|
||
String rawPath = (filePaths != null && i < filePaths.size() && filePaths.get(i) != null && !filePaths.get(i).isBlank())
|
||
? filePaths.get(i)
|
||
: (originalFilename != null ? originalFilename : "");
|
||
|
||
// 2. 标准化路径格式(统一用 / 分隔符,去除开头多余的 ./ 或 /)
|
||
String normalizedPath = rawPath
|
||
.replace("\\", "/") // 统一使用正斜杠
|
||
.replaceAll("^[/.]+", "") // 去除开头的 ./ 或 /
|
||
.replaceAll("/+", "/"); // 替换多个连续的 / 为单个
|
||
|
||
// 打印当前处理的文件信息(包含原始和标准化路径)
|
||
log.info("正在处理第 {} 个文件 - 原始文件名: {}, 标准化路径: {}",
|
||
i + 1,
|
||
originalFilename,
|
||
normalizedPath);
|
||
|
||
try {
|
||
// 读取上传文件,提取文档内容
|
||
log.debug("开始解析文件内容: {}", normalizedPath);
|
||
TikaDocumentReader documentReader = new TikaDocumentReader(file.getResource());
|
||
List<Document> documents = documentReader.get();
|
||
|
||
// 对文档进行 Token 拆分
|
||
log.debug("开始拆分文件内容: {}", normalizedPath);
|
||
List<Document> documentSplitterList = tokenTextSplitter.apply(documents);
|
||
|
||
// ===== 元数据设置(保留完整标准化路径)=====
|
||
// 1. 原始文档设置元数据
|
||
documents.forEach(doc -> {
|
||
doc.getMetadata().put("knowledge", ragTag);
|
||
doc.getMetadata().put("path", normalizedPath); // 使用标准化路径
|
||
doc.getMetadata().put("original_filename", originalFilename); // 可选:额外保留原始文件名
|
||
});
|
||
|
||
// 2. 拆分后的文档设置元数据
|
||
documentSplitterList.forEach(doc -> {
|
||
doc.getMetadata().put("knowledge", ragTag);
|
||
doc.getMetadata().put("path", normalizedPath);
|
||
doc.getMetadata().put("original_filename", originalFilename);
|
||
});
|
||
|
||
// 存储拆分后的文档到 pgVectorStore
|
||
log.debug("开始存储文件到向量数据库: {}", normalizedPath);
|
||
pgVectorStore.accept(documentSplitterList);
|
||
log.info("文件处理完成: {}", normalizedPath);
|
||
|
||
} catch (Exception e) {
|
||
log.error("文件处理失败:{} - {}", normalizedPath, e.getMessage(), e);
|
||
// 可选:记录失败文件信息,但不中断整体流程
|
||
}
|
||
}
|
||
|
||
// 更新 Redis 标签列表(避免重复)
|
||
RList<String> elements = redissonClient.getList("ragTag");
|
||
if (!elements.contains(ragTag)) {
|
||
elements.add(ragTag);
|
||
log.info("新增知识库标签: {}", ragTag);
|
||
} else {
|
||
log.info("知识库标签已存在,无需新增: {}", ragTag);
|
||
}
|
||
|
||
log.info("上传知识库完成:{},共处理 {} 个文件", ragTag, files.size());
|
||
return Response.<String>builder().code("0000").info("调用成功").build();
|
||
}
|
||
|
||
/**
|
||
* 克隆并分析 Git 仓库:
|
||
* - 克隆指定仓库到本地
|
||
* - 遍历文件,使用 Tika 提取并拆分
|
||
* - 存储到 pgVectorStore 并更新 Redis 标签列表
|
||
* POST /api/v1/rag/analyze_git_repository
|
||
*/
|
||
@PostMapping("analyze_git_repository")
|
||
@Override
|
||
public Response<String> analyzeGitRepository(
|
||
@RequestParam("repoUrl") String repoUrl,
|
||
@RequestParam("userName") String userName,
|
||
@RequestParam("token") String token,
|
||
@RequestParam("ragTag") String ragTag) throws Exception {
|
||
|
||
String localPath = "./git-cloned-repo";
|
||
String repoProjectName = StringUtils.isNotBlank(ragTag) ?
|
||
ragTag :
|
||
extractProjectName(repoUrl);
|
||
log.info("克隆路径:{}", new File(localPath).getAbsolutePath());
|
||
|
||
// 1. 干净克隆
|
||
FileUtils.deleteDirectory(new File(localPath));
|
||
Git git = Git.cloneRepository()
|
||
.setURI(repoUrl)
|
||
.setDirectory(new File(localPath))
|
||
.setCredentialsProvider(new UsernamePasswordCredentialsProvider(userName, token))
|
||
.call();
|
||
|
||
// 2. 遍历并处理文件
|
||
Files.walkFileTree(Paths.get(localPath), new SimpleFileVisitor<>() {
|
||
@Override
|
||
public FileVisitResult preVisitDirectory(Path dir, BasicFileAttributes attrs) {
|
||
// 跳过 .git 目录
|
||
if (".git".equals(dir.getFileName().toString())) {
|
||
return FileVisitResult.SKIP_SUBTREE;
|
||
}
|
||
return FileVisitResult.CONTINUE;
|
||
}
|
||
|
||
@Override
|
||
public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) {
|
||
log.info("解析并上传文件:{} -> {}", repoProjectName, file.getFileName());
|
||
try {
|
||
// 2.1 读取原始文档(不可变列表)
|
||
List<Document> raw = new TikaDocumentReader(new PathResource(file)).get();
|
||
|
||
// 2.2 复制为可变列表并过滤掉空内容
|
||
List<Document> docs = new ArrayList<>(raw);
|
||
docs.removeIf(d -> d.getText() == null || d.getText().trim().isEmpty());
|
||
if (docs.isEmpty()) {
|
||
return FileVisitResult.CONTINUE;
|
||
}
|
||
|
||
// 2.3 打标签
|
||
docs.forEach(d -> d.getMetadata().put("knowledge", repoProjectName));
|
||
|
||
// 2.4 拆分并打标签
|
||
List<Document> splits = tokenTextSplitter.apply(docs);
|
||
splits.forEach(d -> d.getMetadata().put("knowledge", repoProjectName));
|
||
|
||
// 2.5 写入向量库
|
||
pgVectorStore.accept(splits);
|
||
} catch (Exception e) {
|
||
// 无法读取、拆分或存储时记录错误并跳过
|
||
log.error("文件解析上传失败:{}", file.getFileName(), e);
|
||
}
|
||
return FileVisitResult.CONTINUE;
|
||
}
|
||
|
||
@Override
|
||
public FileVisitResult visitFileFailed(Path file, IOException exc) {
|
||
// 文件访问失败时也不影响整体执行
|
||
log.warn("访问文件失败:{} - {}", file, exc.getMessage());
|
||
return FileVisitResult.CONTINUE;
|
||
}
|
||
});
|
||
|
||
// 3. 清理本地
|
||
FileUtils.deleteDirectory(new File(localPath));
|
||
git.close();
|
||
|
||
// 4. 更新 Redis 标签列表
|
||
RList<String> elements = redissonClient.getList("ragTag");
|
||
if (!elements.contains(repoProjectName)) {
|
||
elements.add(repoProjectName);
|
||
}
|
||
|
||
log.info("仓库分析并上传完成:{}", repoUrl);
|
||
return Response.<String>builder().code("0000").info("调用成功").build();
|
||
}
|
||
|
||
/**
|
||
* 测试接口
|
||
* @return
|
||
*/
|
||
@GetMapping("knowledge/_ping")
|
||
public Response<String> ragPing() {
|
||
log.info("RAG knowledge ping");
|
||
return Response.<String>builder().code("0000").info("ok").build();
|
||
}
|
||
|
||
|
||
/**
|
||
* 删除指定知识库(按 ragTag):
|
||
* - 从 pgvector 向量库中删除所有 metadata.knowledge == ragTag 的文档
|
||
* - 从 Redis ragTag 列表中删除该标签
|
||
* DELETE /api/v1/rag/knowledge/{ragTag}
|
||
*/
|
||
@DeleteMapping("knowledge/{ragTag}")
|
||
public Response<String> deleteKnowledge(@PathVariable("ragTag") String ragTag) {
|
||
log.info("删除知识库开始:{}", ragTag);
|
||
try {
|
||
// 1) 删除向量库中对应知识库的所有向量(1.0.0-M6:支持过滤表达式删除)
|
||
pgVectorStore.delete("knowledge == '" + ragTag + "'");
|
||
|
||
// 2) 从 Redis 标签列表移除
|
||
RList<String> elements = redissonClient.getList("ragTag");
|
||
elements.remove(ragTag);
|
||
|
||
log.info("删除知识库完成:{}", ragTag);
|
||
return Response.<String>builder().code("0000").info("删除成功").build();
|
||
} catch (Exception e) {
|
||
log.error("删除知识库失败:{}", ragTag, e);
|
||
return Response.<String>builder().code("9999").info("删除失败:" + e.getMessage()).build();
|
||
}
|
||
}
|
||
|
||
/**
|
||
* 从 Git 仓库 URL 提取项目名称(去除 .git 后缀)
|
||
*/
|
||
private String extractProjectName(String repoUrl) {
|
||
String[] parts = repoUrl.split("/");
|
||
String projectNameWithGit = parts[parts.length - 1];
|
||
return projectNameWithGit.replace(".git", "");
|
||
}
|
||
}
|