Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 2 additions & 24 deletions bigtop-manager-ai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -48,30 +48,8 @@
<artifactId>bigtop-manager-dao</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-reactor</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-community-qianfan</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-community-dashscope</artifactId>
<exclusions>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
</exclusion>
</exclusions>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@

import org.springframework.stereotype.Component;

import dev.langchain4j.service.tool.ToolProvider;

import jakarta.annotation.Resource;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -70,8 +68,7 @@ private AIAssistant.Builder initializeBuilder(PlatformType platformType) {
}

@Override
public AIAssistant createWithPrompt(
AIAssistantConfig config, ToolProvider toolProvider, SystemPrompt systemPrompt) {
public AIAssistant createWithPrompt(AIAssistantConfig config, Object toolProvider, SystemPrompt systemPrompt) {
GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config;
PlatformType platformType = generalAssistantConfig.getPlatformType();
Object id = generalAssistantConfig.getId();
Expand All @@ -81,25 +78,23 @@ public AIAssistant createWithPrompt(

AIAssistant.Builder builder = initializeBuilder(platformType);
builder.id(id)
.memoryStore(chatMemoryStoreProvider.createPersistentChatMemoryStore())
.withConfig(generalAssistantConfig)
.withToolProvider(toolProvider);
.memoryStore(chatMemoryStoreProvider.createPersistentChatMemoryStore(id))
.withConfig(generalAssistantConfig);

configureSystemPrompt(builder, systemPrompt, generalAssistantConfig.getLanguage());

return builder.build();
}

@Override
public AIAssistant createForTest(AIAssistantConfig config, ToolProvider toolProvider) {
public AIAssistant createForTest(AIAssistantConfig config, Object toolProvider) {
GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config;
PlatformType platformType = generalAssistantConfig.getPlatformType();
AIAssistant.Builder builder = initializeBuilder(platformType);

builder.id(null)
.memoryStore(chatMemoryStoreProvider.createInMemoryChatMemoryStore())
.withConfig(generalAssistantConfig)
.withToolProvider(toolProvider);
.withConfig(generalAssistantConfig);

return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
import org.apache.bigtop.manager.dao.repository.ChatThreadDao;

import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.stereotype.Component;

import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;

import jakarta.annotation.Resource;

@Component
Expand All @@ -37,11 +37,17 @@ public class ChatMemoryStoreProvider {
@Resource
private ChatMessageDao chatMessageDao;

public ChatMemoryStore createPersistentChatMemoryStore() {
return new PersistentChatMemoryStore(chatThreadDao, chatMessageDao);
public ChatMemory createPersistentChatMemoryStore(Object conversationId) {
PersistentChatMemoryStore repository =
new PersistentChatMemoryStore((Long) conversationId, chatThreadDao, chatMessageDao);
return MessageWindowChatMemory.builder()
.chatMemoryRepository(repository)
.build();
}

public ChatMemoryStore createInMemoryChatMemoryStore() {
return new InMemoryChatMemoryStore();
public ChatMemory createInMemoryChatMemoryStore() {
return MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
import org.apache.bigtop.manager.dao.repository.ChatThreadDao;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;

import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
Expand All @@ -38,41 +38,49 @@
import java.util.stream.Collectors;

@Slf4j
public class PersistentChatMemoryStore implements ChatMemoryStore {
public class PersistentChatMemoryStore implements ChatMemoryRepository {

private final List<ChatMessage> messagesInMemory = new ArrayList<>();
private final List<Message> messagesInMemory = new ArrayList<>();
private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;
private final Long conversationId;

public PersistentChatMemoryStore(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
public PersistentChatMemoryStore(Long conversationId, ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.conversationId = conversationId;
this.chatThreadDao = chatThreadDao;
this.chatMessageDao = chatMessageDao;
}

private ChatMessage convertToChatMessage(ChatMessagePO chatMessagePO) {
private Message convertToChatMessage(ChatMessagePO chatMessagePO) {
String sender = chatMessagePO.getSender().toLowerCase();
if (sender.equals(MessageType.AI.getValue())) {
return new AiMessage(chatMessagePO.getMessage());
return new AssistantMessage(chatMessagePO.getMessage());
} else if (sender.equals(MessageType.USER.getValue())) {
return new UserMessage(chatMessagePO.getMessage());
} else if (sender.equals(MessageType.SYSTEM.getValue())) {
return new SystemMessage(chatMessagePO.getMessage());
} else {
return null;
}
}

private ChatMessagePO convertToChatMessagePO(ChatMessage chatMessage, Long chatThreadId) {
private ChatMessagePO convertToChatMessagePO(Message message, Long chatThreadId) {
ChatMessagePO chatMessagePO = new ChatMessagePO();
if (chatMessage.type().equals(ChatMessageType.AI)) {
if (message.getMessageType() == org.springframework.ai.chat.messages.MessageType.ASSISTANT) {
chatMessagePO.setSender(MessageType.AI.getValue());
AiMessage aiMessage = (AiMessage) chatMessage;
if (aiMessage.text() == null) {
AssistantMessage assistantMessage = (AssistantMessage) message;
if (assistantMessage.getText() == null) {
return null;
}
chatMessagePO.setMessage(aiMessage.text());
} else if (chatMessage.type().equals(ChatMessageType.USER)) {
chatMessagePO.setMessage(assistantMessage.getText());
} else if (message.getMessageType() == org.springframework.ai.chat.messages.MessageType.USER) {
chatMessagePO.setSender(MessageType.USER.getValue());
UserMessage userMessage = (UserMessage) chatMessage;
chatMessagePO.setMessage(userMessage.singleText());
UserMessage userMessage = (UserMessage) message;
chatMessagePO.setMessage(userMessage.getText());
} else if (message.getMessageType() == org.springframework.ai.chat.messages.MessageType.SYSTEM) {
chatMessagePO.setSender(MessageType.SYSTEM.getValue());
SystemMessage systemMessage = (SystemMessage) message;
chatMessagePO.setMessage(systemMessage.getText());
} else {
return null;
}
Expand All @@ -82,11 +90,11 @@ private ChatMessagePO convertToChatMessagePO(ChatMessage chatMessage, Long chatT
return chatMessagePO;
}

private List<ChatMessage> sortMessages(List<ChatMessage> messages) {
List<ChatMessage> systemMessages = messages.stream()
private List<Message> sortMessages(List<Message> messages) {
List<Message> systemMessages = messages.stream()
.filter(message -> message instanceof SystemMessage)
.collect(Collectors.toList());
List<ChatMessage> otherMessages = messages.stream()
List<Message> otherMessages = messages.stream()
.filter(message -> !(message instanceof SystemMessage))
.toList();

Expand All @@ -95,9 +103,15 @@ private List<ChatMessage> sortMessages(List<ChatMessage> messages) {
}

@Override
public List<ChatMessage> getMessages(Object threadId) {
List<ChatMessagePO> chatMessages = chatMessageDao.findAllByThreadId((Long) threadId);
List<ChatMessage> allChatMessages = new ArrayList<>();
public List<String> findConversationIds() {
// Return the current conversation ID as a list
return List.of(String.valueOf(conversationId));
}

@Override
public List<Message> findByConversationId(String conversationId) {
List<ChatMessagePO> chatMessages = chatMessageDao.findAllByThreadId(this.conversationId);
List<Message> allChatMessages = new ArrayList<>();
if (!chatMessages.isEmpty()) {
allChatMessages.addAll(chatMessages.stream()
.map(this::convertToChatMessage)
Expand All @@ -111,20 +125,22 @@ public List<ChatMessage> getMessages(Object threadId) {
}

@Override
public void updateMessages(Object threadId, List<ChatMessage> messages) {
ChatMessage newMessage = messages.get(messages.size() - 1);
ChatMessagePO chatMessagePO = convertToChatMessagePO(newMessage, (Long) threadId);
if (chatMessagePO == null) {
messagesInMemory.add(newMessage);
return;
public void saveAll(String conversationId, List<Message> messages) {
for (Message message : messages) {
ChatMessagePO chatMessagePO = convertToChatMessagePO(message, this.conversationId);
if (chatMessagePO == null) {
messagesInMemory.add(message);
continue;
}
chatMessageDao.save(chatMessagePO);
}
chatMessageDao.save(chatMessagePO);
}

@Override
public void deleteMessages(Object threadId) {
List<ChatMessagePO> chatMessagePOS = chatMessageDao.findAllByThreadId((Long) threadId);
public void deleteByConversationId(String conversationId) {
List<ChatMessagePO> chatMessagePOS = chatMessageDao.findAllByThreadId(this.conversationId);
chatMessagePOS.forEach(chatMessage -> chatMessage.setIsDeleted(true));
chatMessageDao.partialUpdateByIds(chatMessagePOS);
messagesInMemory.clear();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,20 @@
import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;

import reactor.core.publisher.Flux;

public abstract class AbstractAIAssistant implements AIAssistant {
protected final AIAssistant.Service aiServices;
protected static final Integer MEMORY_LEN = 10;
protected final ChatMemory chatMemory;
protected final Object memoryId;

protected AbstractAIAssistant(ChatMemory chatMemory, AIAssistant.Service aiServices) {
protected AbstractAIAssistant(Object memoryId, ChatMemory chatMemory, AIAssistant.Service aiServices) {
this.memoryId = memoryId;
this.chatMemory = chatMemory;
this.aiServices = aiServices;
}
Expand All @@ -44,7 +46,7 @@ public boolean test() {

@Override
public Object getId() {
return chatMemory.id();
return memoryId;
}

@Override
Expand All @@ -60,19 +62,13 @@ public String ask(String chatMessage) {
public abstract static class Builder implements AIAssistant.Builder {
protected Object id;

protected ChatMemoryStore chatMemoryStore;
protected ChatMemory chatMemory;
protected AIAssistantConfig config;

protected ToolProvider toolProvider;
protected String systemPrompt;

public Builder() {}

public Builder withToolProvider(ToolProvider toolProvider) {
this.toolProvider = toolProvider;
return this;
}

public Builder withSystemPrompt(String systemPrompt) {
this.systemPrompt = systemPrompt;
return this;
Expand All @@ -88,19 +84,18 @@ public Builder id(Object id) {
return this;
}

public Builder memoryStore(ChatMemoryStore chatMemoryStore) {
this.chatMemoryStore = chatMemoryStore;
public Builder memoryStore(ChatMemory chatMemory) {
this.chatMemory = chatMemory;
return this;
}

public MessageWindowChatMemory getChatMemory() {
MessageWindowChatMemory.Builder builder = MessageWindowChatMemory.builder()
.chatMemoryStore(chatMemoryStore)
.maxMessages(MEMORY_LEN);
if (id != null) {
builder.id(id);
public ChatMemory getChatMemory() {
if (chatMemory == null) {
chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(new InMemoryChatMemoryRepository())
.build();
}
return builder.build();
return chatMemory;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.StreamingChatModel;

import reactor.core.publisher.Flux;

public interface AIAssistant {
Expand Down Expand Up @@ -72,12 +71,10 @@ interface Service {
interface Builder {
Builder id(Object id);

Builder memoryStore(ChatMemoryStore memoryStore);
Builder memoryStore(ChatMemory memoryStore);

Builder withConfig(AIAssistantConfig configProvider);

Builder withToolProvider(ToolProvider toolProvider);

Builder withSystemPrompt(String systemPrompt);

AIAssistant build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,13 @@
import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig;
import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;

import dev.langchain4j.service.tool.ToolProvider;

public interface AIAssistantFactory {

AIAssistant createWithPrompt(AIAssistantConfig config, ToolProvider toolProvider, SystemPrompt systemPrompt);
AIAssistant createWithPrompt(AIAssistantConfig config, Object toolProvider, SystemPrompt systemPrompt);

AIAssistant createForTest(AIAssistantConfig config, ToolProvider toolProvider);
AIAssistant createForTest(AIAssistantConfig config, Object toolProvider);

default AIAssistant createAIService(AIAssistantConfig config, ToolProvider toolProvider) {
default AIAssistant createAIService(AIAssistantConfig config, Object toolProvider) {
return createWithPrompt(config, toolProvider, SystemPrompt.DEFAULT_PROMPT);
}
}
Loading
Loading