Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/main/java/com/example/spring/app/llm/LLMConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@ public ChatClient chatClient(ChatClient.Builder chatClientBuilder, ChatMemory ch
.build();
}

@Bean
public ChatClient titleClient(ChatClient.Builder titleClientBuilder) {
return titleClientBuilder
.defaultSystem(
"""
You generate a short title for a chat conversation.

Rules:
- Output exactly ONE line (no quotes, no punctuation, no line breaks).
- Max 6 words.
- No lists, no categories.
- If the conversation is generic small talk (e.g., greetings, “how are you”, “what’s up”), output: "Small talk".
- Do not guess topics like weather unless explicitly discussed.
"""
)
.build();
}

@Bean
public ChatMemory jdbcChatMemory(JdbcTemplate jdbcTemplate) {
ChatMemoryRepository chatMemoryRepository = JdbcChatMemoryRepository.builder()
Expand All @@ -59,7 +77,7 @@ public ChatMemory jdbcChatMemory(JdbcTemplate jdbcTemplate) {

return MessageWindowChatMemory.builder()
.chatMemoryRepository(chatMemoryRepository)
.maxMessages(10)
.maxMessages(20)
.build();
}
}
37 changes: 28 additions & 9 deletions src/main/java/com/example/spring/app/llm/LLMController.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,46 @@
import java.time.Instant;
import java.util.List;

import static com.example.spring.app.llm.LLMUtils.wrapUserInputWithConversationContext;
import static com.example.spring.common.utils.JwtUtil.extractUserIdFromHeader;

@CrossOrigin
@RestController
@RequestMapping("/v1/chat")
public class LLMController {

private final ChatClient titleClient;
private final ChatClient chatClient;
private final UserConversationService userConversationService;
private final SpringAiChatMemoryService springAiChatMemoryService;

public LLMController(ChatClient chatClient, UserConversationService userConversationService, SpringAiChatMemoryService springAiChatMemoryService) {
public LLMController(ChatClient titleClient, ChatClient chatClient, UserConversationService userConversationService, SpringAiChatMemoryService springAiChatMemoryService) {
this.titleClient = titleClient;
this.chatClient = chatClient;
this.userConversationService = userConversationService;
this.springAiChatMemoryService = springAiChatMemoryService;
}

@PostMapping(value = "/conversation/{conversationId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
@PostMapping(value = "/{conversationId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<ChatStreamResponseDTO> streamGeneration(@PathVariable String conversationId, @RequestBody LLMRequest request) {
String userId = extractUserIdFromHeader();
String userInput = wrapUserInputWithConversationContext(request.userInput());
boolean isNewConversation = conversationId.equals("new");
String conversationTitle = isNewConversation
? titleClient.prompt().user(userInput).call().content()
: null;

UserConversationModel conversation =
conversationId.equals("new")
? userConversationService.createNewConversationForUser(userId)
isNewConversation
? userConversationService.createNewConversationForUser(userId, conversationTitle)
: userConversationService.getUserConversation(conversationId, userId);

if (conversation == null) {
throw new RuntimeException("Conversation not found for user. Mismatched user or conversation ID.");
}

return chatClient.prompt()
.user(userSpec -> userSpec.text(request.userInput()))
.user(userSpec -> userSpec.text(userInput))
.advisors(advisor -> advisor.param(ChatMemory.CONVERSATION_ID, conversation.getConversationId()))
.stream()
.chatResponse()
Expand All @@ -55,7 +63,7 @@ public Flux<ChatStreamResponseDTO> streamGeneration(@PathVariable String convers


// TODO: Add pagination to this endpoint
@GetMapping("/conversation/history/{conversationId}")
@GetMapping("/history/{conversationId}")
public List<MessageDTO> getConversationHistory(@PathVariable String conversationId) {
String userId = extractUserIdFromHeader();
UserConversationModel conversation = userConversationService.getUserConversation(conversationId, userId);
Expand All @@ -70,20 +78,31 @@ public List<MessageDTO> getConversationHistory(@PathVariable String conversation
}

// Issue with open-api generator which generates ENUM, and values are the sames
@DeleteMapping("/conversation/delete/{conversationId}")
@DeleteMapping("/delete/{conversationId}")
public void deleteConversation(@PathVariable String conversationId) {
String userId = extractUserIdFromHeader();
springAiChatMemoryService.deleteAllByConversationId(conversationId);
userConversationService.deleteUserConversation(conversationId, userId);
}

@GetMapping("/conversation/all")
@GetMapping("/single/{conversationId}")
public ConversationDTO getSingleConversation(@PathVariable String conversationId) {
String userId = extractUserIdFromHeader();
UserConversationModel conversation = userConversationService.getUserConversation(conversationId, userId);
if (conversation == null) {
throw new RuntimeException("Conversation not found for user. Mismatched user or conversation ID.");
}

return new ConversationDTO(conversation.getConversationId(), conversation.getTitle());
}

@GetMapping("/all")
public List<ConversationDTO> getAllUserConversations() {
String userId = extractUserIdFromHeader();
List<UserConversationModel> conversations = userConversationService.getAllConversationsForUser(userId);

return conversations.stream()
.map(conv -> new ConversationDTO(conv.getConversationId()))
.map(conv -> new ConversationDTO(conv.getConversationId(), conv.getTitle()))
.toList();
}
}
25 changes: 25 additions & 0 deletions src/main/java/com/example/spring/app/llm/LLMUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.example.spring.app.llm;

public class LLMUtils {
public static String wrapUserInputWithConversationContext(String userInput) {
return """
Conversation:
%s
""".formatted(userInput);
}

/**
* Parse conversation title to limit to max 6 words.
* If more than 6 words, append "..." at the end.
*/
public static String parseConversationTitle(String conversationTitle) {
String parsedTitle = conversationTitle;
if (conversationTitle != null) {
String[] words = conversationTitle.split("\\s+");
if (words.length > 6) {
parsedTitle = String.join(" ", java.util.Arrays.copyOfRange(words, 0, 6)) + "...";
}
}
return parsedTitle;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.example.spring.app.llm.dto;

public record ConversationDTO(
String conversationId
//String title
String conversationId,
String title
) {}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ public class UserConversationModel {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Integer id;
private String title;
private String conversationId;
private String userId;
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import java.util.List;
import java.util.UUID;

import static com.example.spring.app.llm.LLMUtils.parseConversationTitle;

@Service
public class UserConversationService {
@Autowired
Expand All @@ -15,10 +17,15 @@ public UserConversationModel getUserConversation(String conversationId, String u
return userConversationRepository.findByConversationIdAndUserId(conversationId, userId);
}

public UserConversationModel createNewConversationForUser(String userId) {
public UserConversationModel createNewConversationForUser(String userId, String conversationTitle) {
UserConversationModel newConversation = new UserConversationModel();
newConversation.setUserId(userId);

String parsedTitle = parseConversationTitle(conversationTitle);
newConversation.setTitle(parsedTitle);

String conversationId = UUID.randomUUID().toString();

newConversation.setUserId(userId);
newConversation.setConversationId(conversationId);
return userConversationRepository.save(newConversation);
}
Expand Down