Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ dependencies {
implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
implementation 'org.springframework.boot:spring-boot-starter-oauth2-resource-server'
implementation 'org.springframework.boot:spring-boot-starter-web'
implementation platform("org.springframework.ai:spring-ai-bom:1.0.0")
implementation 'org.springframework.ai:spring-ai-starter-model-openai'
implementation 'org.springframework.boot:spring-boot-starter-security'
implementation 'org.springdoc:springdoc-openapi-starter-webmvc-ui:2.6.0'
implementation 'org.springframework.boot:spring-boot-starter-validation'

// Spring AI
implementation platform("org.springframework.ai:spring-ai-bom:1.1.2")
implementation 'org.springframework.ai:spring-ai-starter-model-chat-memory-repository-jdbc'
implementation 'org.springframework.ai:spring-ai-starter-model-openai'

// Stats
implementation 'org.springframework.boot:spring-boot-starter-actuator'
implementation 'io.micrometer:micrometer-registry-prometheus'
Expand Down
1 change: 0 additions & 1 deletion src/main/java/com/example/spring/Application.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
@EnableCaching
@EnableConfigurationProperties(EnvConfig.class)
public class Application {

public static void main(String[] args) {
SpringApplication.run(Application.class, args);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import java.util.List;
import java.util.stream.Collectors;

import static com.example.spring.common.utils.JwtUtil.extractUserIdFromToken;
import static com.example.spring.common.utils.JwtUtil.extractUserIdFromHeader;

@CrossOrigin
@RestController
Expand All @@ -32,7 +32,7 @@ public class CompanyController {
// Example: http://localhost:8080/api/v1/company/get-by-id/123
@GetMapping("/get-by-id/{id}")
public CompanyDtoWithStatusDTO getCompanyById(@PathVariable("id") Integer id) {
String userId = extractUserIdFromToken();
String userId = extractUserIdFromHeader();
CompanyDTO companyDto = companyService.getCompanyById(id).toCompanyDTO();
UserCompanyStatusModel userCompanyStatus = userCompanyStatusService
.getOneUserCompanyStatusByUserIdAndCompanyId(userId, id);
Expand All @@ -45,7 +45,7 @@ public CompanyDtoWithStatusDTO getCompanyById(@PathVariable("id") Integer id) {
public Page<CompanyDtoWithStatusDTO> getCompaniesSeenByUser(@RequestParam(defaultValue = "0") int page,
@RequestParam(defaultValue = "10") int size) {
Pageable pageable = PageRequest.of(page, size);
String userId = extractUserIdFromToken();
String userId = extractUserIdFromHeader();
Page<CompanyModel> companies = companyService.getCompaniesSeenByUser(userId, pageable);

List<UserCompanyStatusModel> userCompanyStatuses = userCompanyStatusService
Expand All @@ -70,7 +70,7 @@ public Page<CompanyDetails> searchCompaniesByName(@RequestParam("companyName") S
@PostMapping("/filter-by-parameters")
public Page<CompanyDtoWithStatusDTO> getCompaniesByFilters(
@RequestBody(required = false) CompanyFilterRequest filterRequest) {
String userId = extractUserIdFromToken();
String userId = extractUserIdFromHeader();
Pageable pageable = PageRequest.of(filterRequest.getPage(), filterRequest.getSize());

Page<CompanyModel> companies = companyService.findCompaniesByFilters(
Expand Down Expand Up @@ -100,7 +100,7 @@ public Page<CompanyDtoWithStatusDTO> getCompaniesByFilters(
public Page<CompanyDtoWithStatusDTO> getRandomUnseenCompanies(@RequestParam(defaultValue = "0") int page,
@RequestParam(defaultValue = "10") int size) {
Pageable pageable = PageRequest.of(page, size);
String userId = extractUserIdFromToken();
String userId = extractUserIdFromHeader();
Page<CompanyModel> companies = companyService.findRandomUnseenCompanies(userId, pageable);
List<UserCompanyStatusModel> userCompanyStatuses = userCompanyStatusService
.getMultipleUserCompanyStatusByUserIdAndCompanyIds(userId, companies.getContent()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ public class LeaderController {

// Example: http://localhost:8080/api/v1/leader/get-by-id/123
@GetMapping("/get-by-id/{id}")
public LeaderModel getLeaderById(@PathVariable("id") Integer id) {
public LeaderModel getLeaderById(@PathVariable Integer id) {
return leaderService.getLeaderById(id);
}

// Example: http://localhost:8080/api/v1/leader/get-by-siren?siren=exemple
@GetMapping("/get-by-siren/{siren}")
public List<LeaderModel> getLeaderBySiren(@PathVariable("siren") String siren) {
public List<LeaderModel> getLeaderBySiren(@PathVariable String siren) {
return leaderService.getLeadersBySirens(siren);
}

Expand Down
10 changes: 0 additions & 10 deletions src/main/java/com/example/spring/app/llm/LLMAnswerDTO.java

This file was deleted.

42 changes: 42 additions & 0 deletions src/main/java/com/example/spring/app/llm/LLMConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.example.spring.app.llm;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
import org.springframework.ai.chat.memory.repository.jdbc.PostgresChatMemoryRepositoryDialect;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.core.JdbcTemplate;

@Configuration
public class LLMConfig {
@Bean
public ChatClient chatClient(ChatClient.Builder chatClientBuilder, ChatMemory chatMemory) {
return chatClientBuilder
.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
.defaultSystem(
"You are a helpful assistant which is named Pierre. You are developed by the goat Mathieu." +
"You always answer concisely and clearly. Don't be verbose." +
"Do not translate into another language unless explicitly asked. " +
"Very important: Always respond in Markdown." +
"Very important: Use a Marseillais accent when speaking french."
)
.build();
}

@Bean
public ChatMemory jdbcChatMemory(JdbcTemplate jdbcTemplate) {
ChatMemoryRepository chatMemoryRepository = JdbcChatMemoryRepository.builder()
.jdbcTemplate(jdbcTemplate)
.dialect(new PostgresChatMemoryRepositoryDialect())
.build();

return MessageWindowChatMemory.builder()
.chatMemoryRepository(chatMemoryRepository)
.maxMessages(10)
.build();
}
}
78 changes: 60 additions & 18 deletions src/main/java/com/example/spring/app/llm/LLMController.java
Original file line number Diff line number Diff line change
@@ -1,39 +1,81 @@
package com.example.spring.app.llm;

import com.example.spring.app.llm.dto.ChatStreamResponseDTO;
import com.example.spring.app.llm.dto.ConversationDTO;
import com.example.spring.app.llm.dto.MessageDTO;
import com.example.spring.app.llm.springAiChatMemory.SpringAiChatMemoryService;
import com.example.spring.app.llm.userConversation.UserConversationModel;
import com.example.spring.app.llm.userConversation.UserConversationService;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;

import java.time.Instant;
import java.util.List;

import static com.example.spring.common.utils.JwtUtil.extractUserIdFromHeader;

@CrossOrigin
@RestController
@RequestMapping("/v1")
@RequestMapping("/v1/chat")
public class LLMController {

private final ChatClient chatClient;
private final UserConversationService userConversationService;
private final SpringAiChatMemoryService springAiChatMemoryService;

public LLMController(ChatClient.Builder chatClientBuilder) {
this.chatClient = chatClientBuilder.build();
public LLMController(ChatClient chatClient, UserConversationService userConversationService, SpringAiChatMemoryService springAiChatMemoryService) {
this.chatClient = chatClient;
this.userConversationService = userConversationService;
this.springAiChatMemoryService = springAiChatMemoryService;
}

@GetMapping("/ask-ai")
LLMAnswerDTO generation(String userInput) {
LLMAnswerDTO llmAnswerDTO = new LLMAnswerDTO();
String response = this.chatClient.prompt()
.user(userInput)
.call()
.content();
@PostMapping(value = "/conversation/{conversationId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<ChatStreamResponseDTO> streamGeneration(@PathVariable String conversationId, @RequestBody LLMRequest request) {
String userId = extractUserIdFromHeader();

llmAnswerDTO.setAnswer(response);
return llmAnswerDTO;
}
UserConversationModel conversation =
conversationId.equals("new")
? userConversationService.createNewConversationForUser(userId)
: userConversationService.getUserConversation(conversationId, userId);

if (conversation == null) {
throw new RuntimeException("Conversation not found for user. Mismatched user or conversation ID.");
}

@GetMapping(value = "/stream-ai", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<ChatResponse> streamGeneration(@RequestParam String userInput) {
return chatClient.prompt()
.user(userInput)
.user(userSpec -> userSpec.text(request.userInput()))
.advisors(advisor -> advisor.param(ChatMemory.CONVERSATION_ID, conversation.getConversationId()))
.stream()
.chatResponse();
.chatResponse()
.map(chatResponse -> new ChatStreamResponseDTO(conversation.getConversationId(), chatResponse, Instant.now().toEpochMilli()));
}


// TODO: Add pagination to this endpoint
@GetMapping("/conversation/history/{conversationId}")
public List<MessageDTO> getConversationHistory(@PathVariable String conversationId) {
String userId = extractUserIdFromHeader();
UserConversationModel conversation = userConversationService.getUserConversation(conversationId, userId);

if (conversation == null) {
throw new RuntimeException("Conversation not found for user. Mismatched user or conversation ID.");
}

return springAiChatMemoryService.findAllByConversationId(conversationId).stream()
.map(chatMemoryModel -> new MessageDTO(chatMemoryModel.getContent(), chatMemoryModel.getType(), chatMemoryModel.getTimestamp()))
.toList();
}

@GetMapping("/conversation/all")
public List<ConversationDTO> getAllUserConversations() {
String userId = extractUserIdFromHeader();
List<UserConversationModel> conversations = userConversationService.getAllConversationsForUser(userId);

return conversations.stream()
.map(conv -> new ConversationDTO(conv.getConversationId()))
.toList();
}
}
5 changes: 5 additions & 0 deletions src/main/java/com/example/spring/app/llm/LLMRequest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.example.spring.app.llm;

public record LLMRequest(
String userInput
){}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.example.spring.app.llm.dto;

import org.springframework.ai.chat.model.ChatResponse;

public record ChatStreamResponseDTO(
String conversationId,
ChatResponse chatResponse,
Long timestamp
){}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.example.spring.app.llm.dto;

public record ConversationDTO(
String conversationId
//String title
) {}
11 changes: 11 additions & 0 deletions src/main/java/com/example/spring/app/llm/dto/MessageDTO.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.example.spring.app.llm.dto;

import org.springframework.ai.chat.messages.MessageType;

import java.time.LocalDateTime;

public record MessageDTO(
String message,
MessageType messageType,
LocalDateTime timestamp
) {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.example.spring.app.llm.springAiChatMemory;

import jakarta.persistence.*;
import lombok.Getter;
import lombok.Setter;
import org.hibernate.validator.constraints.Length;
import org.springframework.ai.chat.messages.MessageType;

import java.time.LocalDateTime;

@Setter
@Getter
@Entity
@Table(name = "spring_ai_chat_memory")
public class SpringAiChatMemoryModel {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Integer id;

@Length(max = 36)
private String conversationId;

@Enumerated(EnumType.STRING)
@Column(length = 10, nullable = false)
private MessageType type;

private String content;
private LocalDateTime timestamp;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.example.spring.app.llm.springAiChatMemory;

import org.springframework.data.jpa.repository.JpaRepository;

import java.util.List;

public interface SpringAiChatMemoryRepository extends JpaRepository<SpringAiChatMemoryModel, Integer> {
List<SpringAiChatMemoryModel> findAllByConversationId(String conversationId);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.example.spring.app.llm.springAiChatMemory;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.List;

@Service
public class SpringAiChatMemoryService {
@Autowired
private SpringAiChatMemoryRepository userConversationRepository;

public List<SpringAiChatMemoryModel> findAllByConversationId(String conversationId) {
return userConversationRepository.findAllByConversationId(conversationId);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.example.spring.app.llm.userConversation;

import jakarta.persistence.*;
import lombok.Getter;
import lombok.Setter;

@Setter
@Getter
@Entity
@Table(name = "user_conversation")
public class UserConversationModel {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Integer id;
private String conversationId;
private String userId;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.example.spring.app.llm.userConversation;

import org.springframework.data.jpa.repository.JpaRepository;

import java.util.List;

public interface UserConversationRepository extends JpaRepository<UserConversationModel, Integer> {
UserConversationModel findByConversationIdAndUserId(String conversationId, String userId);
List<UserConversationModel> findAllByUserId(String userId);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.example.spring.app.llm.userConversation;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.UUID;

@Service
public class UserConversationService {
@Autowired
private UserConversationRepository userConversationRepository;

public UserConversationModel getUserConversation(String conversationId, String userId) {
return userConversationRepository.findByConversationIdAndUserId(conversationId, userId);
}

public UserConversationModel createNewConversationForUser(String userId) {
UserConversationModel newConversation = new UserConversationModel();
newConversation.setUserId(userId);
String conversationId = UUID.randomUUID().toString();
newConversation.setConversationId(conversationId);
return userConversationRepository.save(newConversation);
}

public List<UserConversationModel> getAllConversationsForUser(String userId) {
return userConversationRepository.findAllByUserId(userId);
}
}
Loading