diff --git a/src/main/java/com/example/spring/app/llm/LLMConfig.java b/src/main/java/com/example/spring/app/llm/LLMConfig.java index 6c67f00..6843f9e 100644 --- a/src/main/java/com/example/spring/app/llm/LLMConfig.java +++ b/src/main/java/com/example/spring/app/llm/LLMConfig.java @@ -50,6 +50,24 @@ public ChatClient chatClient(ChatClient.Builder chatClientBuilder, ChatMemory ch .build(); } + @Bean + public ChatClient titleClient(ChatClient.Builder titleClientBuilder) { + return titleClientBuilder + .defaultSystem( + """ + You generate a short title for a chat conversation. + + Rules: + - Output exactly ONE line (no quotes, no punctuation, no line breaks). + - Max 6 words. + - No lists, no categories. + - If the conversation is generic small talk (e.g., greetings, “how are you”, “what’s up”), output: "Small talk". + - Do not guess topics like weather unless explicitly discussed. + """ + ) + .build(); + } + @Bean public ChatMemory jdbcChatMemory(JdbcTemplate jdbcTemplate) { ChatMemoryRepository chatMemoryRepository = JdbcChatMemoryRepository.builder() @@ -59,7 +77,7 @@ public ChatMemory jdbcChatMemory(JdbcTemplate jdbcTemplate) { return MessageWindowChatMemory.builder() .chatMemoryRepository(chatMemoryRepository) - .maxMessages(10) + .maxMessages(20) .build(); } } diff --git a/src/main/java/com/example/spring/app/llm/LLMController.java b/src/main/java/com/example/spring/app/llm/LLMController.java index 45bc76b..ebb38f5 100644 --- a/src/main/java/com/example/spring/app/llm/LLMController.java +++ b/src/main/java/com/example/spring/app/llm/LLMController.java @@ -15,6 +15,7 @@ import java.time.Instant; import java.util.List; +import static com.example.spring.app.llm.LLMUtils.wrapUserInputWithConversationContext; import static com.example.spring.common.utils.JwtUtil.extractUserIdFromHeader; @CrossOrigin @@ -22,23 +23,30 @@ @RequestMapping("/v1/chat") public class LLMController { + private final ChatClient titleClient; private final ChatClient chatClient; private final UserConversationService userConversationService; private final SpringAiChatMemoryService springAiChatMemoryService; - public LLMController(ChatClient chatClient, UserConversationService userConversationService, SpringAiChatMemoryService springAiChatMemoryService) { + public LLMController(ChatClient titleClient, ChatClient chatClient, UserConversationService userConversationService, SpringAiChatMemoryService springAiChatMemoryService) { + this.titleClient = titleClient; this.chatClient = chatClient; this.userConversationService = userConversationService; this.springAiChatMemoryService = springAiChatMemoryService; } - @PostMapping(value = "/conversation/{conversationId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + @PostMapping(value = "/{conversationId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE) public Flux streamGeneration(@PathVariable String conversationId, @RequestBody LLMRequest request) { String userId = extractUserIdFromHeader(); + String userInput = wrapUserInputWithConversationContext(request.userInput()); + boolean isNewConversation = conversationId.equals("new"); + String conversationTitle = isNewConversation + ? titleClient.prompt().user(userInput).call().content() + : null; UserConversationModel conversation = - conversationId.equals("new") - ? userConversationService.createNewConversationForUser(userId) + isNewConversation + ? userConversationService.createNewConversationForUser(userId, conversationTitle) : userConversationService.getUserConversation(conversationId, userId); if (conversation == null) { @@ -46,7 +54,7 @@ public Flux streamGeneration(@PathVariable String convers } return chatClient.prompt() - .user(userSpec -> userSpec.text(request.userInput())) + .user(userSpec -> userSpec.text(userInput)) .advisors(advisor -> advisor.param(ChatMemory.CONVERSATION_ID, conversation.getConversationId())) .stream() .chatResponse() @@ -55,7 +63,7 @@ public Flux streamGeneration(@PathVariable String convers // TODO: Add pagination to this endpoint - @GetMapping("/conversation/history/{conversationId}") + @GetMapping("/history/{conversationId}") public List getConversationHistory(@PathVariable String conversationId) { String userId = extractUserIdFromHeader(); UserConversationModel conversation = userConversationService.getUserConversation(conversationId, userId); @@ -70,20 +78,31 @@ public List getConversationHistory(@PathVariable String conversation } // Issue with open-api generator which generates ENUM, and values are the sames - @DeleteMapping("/conversation/delete/{conversationId}") + @DeleteMapping("/delete/{conversationId}") public void deleteConversation(@PathVariable String conversationId) { String userId = extractUserIdFromHeader(); springAiChatMemoryService.deleteAllByConversationId(conversationId); userConversationService.deleteUserConversation(conversationId, userId); } - @GetMapping("/conversation/all") + @GetMapping("/single/{conversationId}") + public ConversationDTO getSingleConversation(@PathVariable String conversationId) { + String userId = extractUserIdFromHeader(); + UserConversationModel conversation = userConversationService.getUserConversation(conversationId, userId); + if (conversation == null) { + throw new RuntimeException("Conversation not found for user. Mismatched user or conversation ID."); + } + + return new ConversationDTO(conversation.getConversationId(), conversation.getTitle()); + } + + @GetMapping("/all") public List getAllUserConversations() { String userId = extractUserIdFromHeader(); List conversations = userConversationService.getAllConversationsForUser(userId); return conversations.stream() - .map(conv -> new ConversationDTO(conv.getConversationId())) + .map(conv -> new ConversationDTO(conv.getConversationId(), conv.getTitle())) .toList(); } } diff --git a/src/main/java/com/example/spring/app/llm/LLMUtils.java b/src/main/java/com/example/spring/app/llm/LLMUtils.java new file mode 100644 index 0000000..dc68b0d --- /dev/null +++ b/src/main/java/com/example/spring/app/llm/LLMUtils.java @@ -0,0 +1,25 @@ +package com.example.spring.app.llm; + +public class LLMUtils { + public static String wrapUserInputWithConversationContext(String userInput) { + return """ + Conversation: + %s + """.formatted(userInput); + } + + /** + * Parse conversation title to limit to max 6 words. + * If more than 6 words, append "..." at the end. + */ + public static String parseConversationTitle(String conversationTitle) { + String parsedTitle = conversationTitle; + if (conversationTitle != null) { + String[] words = conversationTitle.split("\\s+"); + if (words.length > 6) { + parsedTitle = String.join(" ", java.util.Arrays.copyOfRange(words, 0, 6)) + "..."; + } + } + return parsedTitle; + } +} diff --git a/src/main/java/com/example/spring/app/llm/dto/ConversationDTO.java b/src/main/java/com/example/spring/app/llm/dto/ConversationDTO.java index 3c36c8a..7ae467d 100644 --- a/src/main/java/com/example/spring/app/llm/dto/ConversationDTO.java +++ b/src/main/java/com/example/spring/app/llm/dto/ConversationDTO.java @@ -1,6 +1,6 @@ package com.example.spring.app.llm.dto; public record ConversationDTO( - String conversationId - //String title + String conversationId, + String title ) {} diff --git a/src/main/java/com/example/spring/app/llm/userConversation/UserConversationModel.java b/src/main/java/com/example/spring/app/llm/userConversation/UserConversationModel.java index 7f5fae1..a4bf864 100644 --- a/src/main/java/com/example/spring/app/llm/userConversation/UserConversationModel.java +++ b/src/main/java/com/example/spring/app/llm/userConversation/UserConversationModel.java @@ -12,6 +12,7 @@ public class UserConversationModel { @Id @GeneratedValue(strategy = GenerationType.IDENTITY) private Integer id; + private String title; private String conversationId; private String userId; } diff --git a/src/main/java/com/example/spring/app/llm/userConversation/UserConversationService.java b/src/main/java/com/example/spring/app/llm/userConversation/UserConversationService.java index ed147c7..87de83b 100644 --- a/src/main/java/com/example/spring/app/llm/userConversation/UserConversationService.java +++ b/src/main/java/com/example/spring/app/llm/userConversation/UserConversationService.java @@ -6,6 +6,8 @@ import java.util.List; import java.util.UUID; +import static com.example.spring.app.llm.LLMUtils.parseConversationTitle; + @Service public class UserConversationService { @Autowired @@ -15,10 +17,15 @@ public UserConversationModel getUserConversation(String conversationId, String u return userConversationRepository.findByConversationIdAndUserId(conversationId, userId); } - public UserConversationModel createNewConversationForUser(String userId) { + public UserConversationModel createNewConversationForUser(String userId, String conversationTitle) { UserConversationModel newConversation = new UserConversationModel(); - newConversation.setUserId(userId); + + String parsedTitle = parseConversationTitle(conversationTitle); + newConversation.setTitle(parsedTitle); + String conversationId = UUID.randomUUID().toString(); + + newConversation.setUserId(userId); newConversation.setConversationId(conversationId); return userConversationRepository.save(newConversation); }