207 lines
8.2 KiB
Java
207 lines
8.2 KiB
Java
|
package edu.whut.trigger.http;
|
|||
|
|
|||
|
import edu.whut.api.IRAGService;
|
|||
|
import edu.whut.api.response.Response;
|
|||
|
import jakarta.annotation.Resource;
|
|||
|
import lombok.RequiredArgsConstructor;
|
|||
|
import lombok.extern.slf4j.Slf4j;
|
|||
|
import org.apache.commons.io.FileUtils;
|
|||
|
import org.eclipse.jgit.api.Git;
|
|||
|
import org.eclipse.jgit.transport.UsernamePasswordCredentialsProvider;
|
|||
|
import org.redisson.api.RList;
|
|||
|
import org.redisson.api.RedissonClient;
|
|||
|
import org.springframework.ai.document.Document;
|
|||
|
import org.springframework.ai.ollama.OllamaChatClient;
|
|||
|
import org.springframework.ai.reader.tika.TikaDocumentReader;
|
|||
|
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
|
|||
|
import org.springframework.ai.vectorstore.PgVectorStore;
|
|||
|
import org.springframework.ai.vectorstore.SimpleVectorStore;
|
|||
|
import org.springframework.core.io.PathResource;
|
|||
|
import org.springframework.http.MediaType;
|
|||
|
import org.springframework.web.bind.annotation.*;
|
|||
|
import org.springframework.web.multipart.MultipartFile;
|
|||
|
|
|||
|
import java.io.File;
|
|||
|
import java.io.IOException;
|
|||
|
import java.nio.file.*;
|
|||
|
import java.nio.file.attribute.BasicFileAttributes;
|
|||
|
import java.util.ArrayList;
|
|||
|
import java.util.List;
|
|||
|
|
|||
|
/**
|
|||
|
* RAG 服务控制器,实现 IRAGService 接口,提供知识库管理和检索相关的 HTTP 接口
|
|||
|
*/
|
|||
|
@Slf4j
|
|||
|
@RestController
|
|||
|
@RequestMapping("/api/v1/rag/")
|
|||
|
@CrossOrigin("*")
|
|||
|
@RequiredArgsConstructor
|
|||
|
public class RAGController implements IRAGService {
|
|||
|
|
|||
|
// Ollama 聊天客户端,用于后续可能的对话调用(此处暂无直接使用)
|
|||
|
private final OllamaChatClient ollamaChatClient;
|
|||
|
|
|||
|
// 文本拆分器,将长文档切分为合适大小的段落或 Token 块
|
|||
|
private final TokenTextSplitter tokenTextSplitter;
|
|||
|
|
|||
|
// 简易内存向量存储,用于快速测试或小规模存储
|
|||
|
private final SimpleVectorStore simpleVectorStore;
|
|||
|
|
|||
|
// PostgreSQL pgvector 存储,用于持久化和检索嵌入向量
|
|||
|
private final PgVectorStore pgVectorStore;
|
|||
|
|
|||
|
// Redisson 客户端,用于操作 Redis 列表存储 RAG 标签
|
|||
|
private final RedissonClient redissonClient;
|
|||
|
|
|||
|
/**
|
|||
|
* 查询所有已上传的 RAG 标签列表
|
|||
|
* GET /api/v1/rag/query_rag_tag_list
|
|||
|
*/
|
|||
|
@GetMapping("query_rag_tag_list")
|
|||
|
@Override
|
|||
|
public Response<List<String>> queryRagTagList() {
|
|||
|
// 从 Redis 列表获取所有标签
|
|||
|
RList<String> elements = redissonClient.getList("ragTag");
|
|||
|
return Response.<List<String>>builder()
|
|||
|
.code("0000")
|
|||
|
.info("调用成功")
|
|||
|
.data(elements)
|
|||
|
.build();
|
|||
|
}
|
|||
|
|
|||
|
/**
|
|||
|
* 上传文件到知识库:
|
|||
|
* - 使用 Tika 读取文档内容
|
|||
|
* - 进行文本切分并贴上 ragTag 元数据
|
|||
|
* - 存储到 pgVectorStore 并更新 Redis 标签列表
|
|||
|
* POST /api/v1/rag/file/upload
|
|||
|
*/
|
|||
|
@PostMapping(path = "file/upload", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
|
|||
|
@Override
|
|||
|
public Response<String> uploadFile(
|
|||
|
@RequestParam("ragTag") String ragTag,
|
|||
|
@RequestParam("file") List<MultipartFile> files) {
|
|||
|
log.info("上传知识库开始:{}", ragTag);
|
|||
|
for (MultipartFile file : files) {
|
|||
|
// 读取上传文件,提取文档内容
|
|||
|
TikaDocumentReader documentReader = new TikaDocumentReader(file.getResource());
|
|||
|
List<Document> documents = documentReader.get();
|
|||
|
// 对文档进行 Token 拆分
|
|||
|
List<Document> documentSplitterList = tokenTextSplitter.apply(documents);
|
|||
|
|
|||
|
// 为原文档和拆分文档设置 ragTag 元数据
|
|||
|
documents.forEach(doc -> doc.getMetadata().put("knowledge", ragTag));
|
|||
|
documentSplitterList.forEach(doc -> doc.getMetadata().put("knowledge", ragTag));
|
|||
|
|
|||
|
// 存储拆分后的文档到 pgVectorStore
|
|||
|
pgVectorStore.accept(documentSplitterList);
|
|||
|
|
|||
|
// 更新 Redis 标签列表,避免重复
|
|||
|
RList<String> elements = redissonClient.getList("ragTag");
|
|||
|
if (!elements.contains(ragTag)) {
|
|||
|
elements.add(ragTag);
|
|||
|
}
|
|||
|
}
|
|||
|
log.info("上传知识库完成:{}", ragTag);
|
|||
|
return Response.<String>builder().code("0000").info("调用成功").build();
|
|||
|
}
|
|||
|
|
|||
|
/**
|
|||
|
* 克隆并分析 Git 仓库:
|
|||
|
* - 克隆指定仓库到本地
|
|||
|
* - 遍历文件,使用 Tika 提取并拆分
|
|||
|
* - 存储到 pgVectorStore 并更新 Redis 标签列表
|
|||
|
* POST /api/v1/rag/analyze_git_repository
|
|||
|
*/
|
|||
|
@PostMapping("analyze_git_repository")
|
|||
|
@Override
|
|||
|
public Response<String> analyzeGitRepository(
|
|||
|
@RequestParam("repoUrl") String repoUrl,
|
|||
|
@RequestParam("userName") String userName,
|
|||
|
@RequestParam("token") String token) throws Exception {
|
|||
|
|
|||
|
String localPath = "./git-cloned-repo";
|
|||
|
String repoProjectName = extractProjectName(repoUrl);
|
|||
|
log.info("克隆路径:{}", new File(localPath).getAbsolutePath());
|
|||
|
|
|||
|
// 1. 干净克隆
|
|||
|
FileUtils.deleteDirectory(new File(localPath));
|
|||
|
Git git = Git.cloneRepository()
|
|||
|
.setURI(repoUrl)
|
|||
|
.setDirectory(new File(localPath))
|
|||
|
.setCredentialsProvider(new UsernamePasswordCredentialsProvider(userName, token))
|
|||
|
.call();
|
|||
|
|
|||
|
// 2. 遍历并处理文件
|
|||
|
Files.walkFileTree(Paths.get(localPath), new SimpleFileVisitor<>() {
|
|||
|
@Override
|
|||
|
public FileVisitResult preVisitDirectory(Path dir, BasicFileAttributes attrs) {
|
|||
|
// 跳过 .git 目录
|
|||
|
if (".git".equals(dir.getFileName().toString())) {
|
|||
|
return FileVisitResult.SKIP_SUBTREE;
|
|||
|
}
|
|||
|
return FileVisitResult.CONTINUE;
|
|||
|
}
|
|||
|
|
|||
|
@Override
|
|||
|
public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) {
|
|||
|
log.info("解析并上传文件:{} -> {}", repoProjectName, file.getFileName());
|
|||
|
try {
|
|||
|
// 2.1 读取原始文档(不可变列表)
|
|||
|
List<Document> raw = new TikaDocumentReader(new PathResource(file)).get();
|
|||
|
|
|||
|
// 2.2 复制为可变列表并过滤掉空内容
|
|||
|
List<Document> docs = new ArrayList<>(raw);
|
|||
|
docs.removeIf(d -> d.getContent() == null || d.getContent().trim().isEmpty());
|
|||
|
if (docs.isEmpty()) {
|
|||
|
return FileVisitResult.CONTINUE;
|
|||
|
}
|
|||
|
|
|||
|
// 2.3 打标签
|
|||
|
docs.forEach(d -> d.getMetadata().put("knowledge", repoProjectName));
|
|||
|
|
|||
|
// 2.4 拆分并打标签
|
|||
|
List<Document> splits = tokenTextSplitter.apply(docs);
|
|||
|
splits.forEach(d -> d.getMetadata().put("knowledge", repoProjectName));
|
|||
|
|
|||
|
// 2.5 写入向量库
|
|||
|
pgVectorStore.accept(splits);
|
|||
|
} catch (Exception e) {
|
|||
|
// 无法读取、拆分或存储时记录错误并跳过
|
|||
|
log.error("文件解析上传失败:{}", file.getFileName(), e);
|
|||
|
}
|
|||
|
return FileVisitResult.CONTINUE;
|
|||
|
}
|
|||
|
|
|||
|
@Override
|
|||
|
public FileVisitResult visitFileFailed(Path file, IOException exc) {
|
|||
|
// 文件访问失败时也不影响整体执行
|
|||
|
log.warn("访问文件失败:{} - {}", file, exc.getMessage());
|
|||
|
return FileVisitResult.CONTINUE;
|
|||
|
}
|
|||
|
});
|
|||
|
|
|||
|
// 3. 清理本地
|
|||
|
FileUtils.deleteDirectory(new File(localPath));
|
|||
|
git.close();
|
|||
|
|
|||
|
// 4. 更新 Redis 标签列表
|
|||
|
RList<String> elements = redissonClient.getList("ragTag");
|
|||
|
if (!elements.contains(repoProjectName)) {
|
|||
|
elements.add(repoProjectName);
|
|||
|
}
|
|||
|
|
|||
|
log.info("仓库分析并上传完成:{}", repoUrl);
|
|||
|
return Response.<String>builder().code("0000").info("调用成功").build();
|
|||
|
}
|
|||
|
|
|||
|
/**
|
|||
|
* 从 Git 仓库 URL 提取项目名称(去除 .git 后缀)
|
|||
|
*/
|
|||
|
private String extractProjectName(String repoUrl) {
|
|||
|
String[] parts = repoUrl.split("/");
|
|||
|
String projectNameWithGit = parts[parts.length - 1];
|
|||
|
return projectNameWithGit.replace(".git", "");
|
|||
|
}
|
|||
|
}
|