diff --git a/frontends/main/src/page-components/AiChat/AiRecommendationBotDrawer.test.tsx b/frontends/main/src/page-components/AiChat/AiRecommendationBotDrawer.test.tsx new file mode 100644 index 0000000000..4cc9048d17 --- /dev/null +++ b/frontends/main/src/page-components/AiChat/AiRecommendationBotDrawer.test.tsx @@ -0,0 +1,38 @@ +import React from "react" +import { renderWithProviders } from "@/test-utils" +import AiRecommendationBotDrawer from "./AiRecommendationBotDrawer" +import { RECOMMENDER_QUERY_PARAM } from "@/common/urls" + +const mockAiChat = jest.fn]>(() => ( +
+)) +jest.mock("@mitodl/smoot-design/ai", () => { + const actual = jest.requireActual("@mitodl/smoot-design/ai") + return { + ...actual, + AiChat: (props: Record) => { + mockAiChat(props) + return actual.AiChat(props) + }, + } +}) + +describe("AiRecommendationBotDrawer", () => { + test("passes CSRF config in requestOpts", () => { + mockAiChat.mockClear() + + renderWithProviders(, { + url: `?${RECOMMENDER_QUERY_PARAM}`, + }) + + const call = mockAiChat.mock.calls[0][0] as Record + const requestOpts = call.requestOpts as Record + expect(requestOpts.csrfCookieName).toBe( + process.env.NEXT_PUBLIC_CSRF_COOKIE_NAME || "csrftoken", + ) + expect(requestOpts.csrfHeaderName).toBe("X-CSRFToken") + expect((requestOpts.fetchOpts as Record).credentials).toBe( + "include", + ) + }) +}) diff --git a/frontends/main/src/page-components/AiChat/AiRecommendationBotDrawer.tsx b/frontends/main/src/page-components/AiChat/AiRecommendationBotDrawer.tsx index e84272bd0b..150fe86e9d 100644 --- a/frontends/main/src/page-components/AiChat/AiRecommendationBotDrawer.tsx +++ b/frontends/main/src/page-components/AiChat/AiRecommendationBotDrawer.tsx @@ -3,7 +3,6 @@ import { styled, RoutedDrawer } from "ol-components" import { RiCloseLine } from "@remixicon/react" import { ActionButton } from "@mitodl/smoot-design" import { AiChat } from "@mitodl/smoot-design/ai" -import { getCsrfToken } from "@/common/client-utils" import { RECOMMENDER_QUERY_PARAM } from "@/common/urls" const CloseButtonContainer = styled("div")({ @@ -72,10 +71,10 @@ const DrawerContent: React.FC<{ scrollElement={scrollElement} requestOpts={{ apiUrl: process.env.NEXT_PUBLIC_LEARN_AI_RECOMMENDATION_ENDPOINT!, + csrfCookieName: + process.env.NEXT_PUBLIC_CSRF_COOKIE_NAME || "csrftoken", + csrfHeaderName: "X-CSRFToken", fetchOpts: { - headers: { - "X-CSRFToken": getCsrfToken(), - }, credentials: "include", }, transformBody: (messages) => ({ diff --git a/frontends/main/src/page-components/LearningResourceExpanded/AiChatSyllabusSlideDown.test.tsx b/frontends/main/src/page-components/LearningResourceExpanded/AiChatSyllabusSlideDown.test.tsx index 74bbc2b978..8d3e93fa23 100644 --- a/frontends/main/src/page-components/LearningResourceExpanded/AiChatSyllabusSlideDown.test.tsx +++ b/frontends/main/src/page-components/LearningResourceExpanded/AiChatSyllabusSlideDown.test.tsx @@ -6,6 +6,20 @@ import { renderWithProviders, screen, user } from "@/test-utils" import { factories } from "api/test-utils" import invariant from "tiny-invariant" +const mockAiChat = jest.fn]>(() => ( +
+)) +jest.mock("@mitodl/smoot-design/ai", () => { + const actual = jest.requireActual("@mitodl/smoot-design/ai") + return { + ...actual, + AiChat: (props: Record) => { + mockAiChat(props) + return actual.AiChat(props) + }, + } +}) + /** * Note: This component is primarily tested in @mitodl/smoot-design. * @@ -35,4 +49,30 @@ describe("AiChatSyllabus", () => { invariant(firstMessage instanceof HTMLElement) expect(firstMessage?.dataset.chatRole).toBe("user") }) + + test("passes CSRF config in requestOpts", () => { + const resource = factories.learningResources.course() + mockAiChat.mockClear() + + renderWithProviders( + , + ) + + const call = mockAiChat.mock.calls[0][0] as Record + const requestOpts = call.requestOpts as Record + expect(requestOpts.csrfCookieName).toBe( + process.env.NEXT_PUBLIC_CSRF_COOKIE_NAME || "csrftoken", + ) + expect(requestOpts.csrfHeaderName).toBe("X-CSRFToken") + expect((requestOpts.fetchOpts as Record).credentials).toBe( + "include", + ) + }) }) diff --git a/frontends/main/src/page-components/LearningResourceExpanded/AiChatSyllabusSlideDown.tsx b/frontends/main/src/page-components/LearningResourceExpanded/AiChatSyllabusSlideDown.tsx index ccd201386d..6b4d12417d 100644 --- a/frontends/main/src/page-components/LearningResourceExpanded/AiChatSyllabusSlideDown.tsx +++ b/frontends/main/src/page-components/LearningResourceExpanded/AiChatSyllabusSlideDown.tsx @@ -5,7 +5,6 @@ import { RiSparkling2Line, RiArrowDownSLine } from "@remixicon/react" import type { AiChatProps } from "@mitodl/smoot-design/ai" import { LearningResource } from "api" import { AiChat } from "@mitodl/smoot-design/ai" -import { getCsrfToken } from "@/common/client-utils" export enum ChatTransitionState { Closed = "Closed", @@ -198,10 +197,10 @@ const AiChatSyllabusSlideDown = ({ scrollElement={scrollElement} requestOpts={{ apiUrl: process.env.NEXT_PUBLIC_LEARN_AI_SYLLABUS_ENDPOINT!, + csrfCookieName: + process.env.NEXT_PUBLIC_CSRF_COOKIE_NAME || "csrftoken", + csrfHeaderName: "X-CSRFToken", fetchOpts: { - headers: { - "X-CSRFToken": getCsrfToken(), - }, credentials: "include", }, transformBody: (messages) => {