diff --git a/components/execd/pkg/runtime/context.go b/components/execd/pkg/runtime/context.go index a1135507..c2f9052a 100644 --- a/components/execd/pkg/runtime/context.go +++ b/components/execd/pkg/runtime/context.go @@ -68,12 +68,15 @@ func (c *Controller) DeleteContext(session string) error { return c.deleteSessionAndCleanup(session) } -func (c *Controller) GetContext(session string) CodeContext { +func (c *Controller) GetContext(session string) (CodeContext, error) { kernel := c.getJupyterKernel(session) + if kernel == nil { + return CodeContext{}, ErrContextNotFound + } return CodeContext{ ID: session, Language: kernel.language, - } + }, nil } func (c *Controller) ListContext(language string) ([]CodeContext, error) { diff --git a/components/execd/pkg/runtime/context_test.go b/components/execd/pkg/runtime/context_test.go index 6a27ad18..34eb956c 100644 --- a/components/execd/pkg/runtime/context_test.go +++ b/components/execd/pkg/runtime/context_test.go @@ -112,6 +112,18 @@ func TestDeleteContext_NotFound(t *testing.T) { } } +func TestGetContext_NotFound(t *testing.T) { + c := NewController("", "") + + _, err := c.GetContext("missing") + if err == nil { + t.Fatalf("expected ErrContextNotFound") + } + if !errors.Is(err, ErrContextNotFound) { + t.Fatalf("unexpected error: %v", err) + } +} + func TestDeleteContext_RemovesCacheOnSuccess(t *testing.T) { sessionID := "sess-123" diff --git a/components/execd/pkg/web/controller/codeinterpreting.go b/components/execd/pkg/web/controller/codeinterpreting.go index 4959ce97..df4a28db 100644 --- a/components/execd/pkg/web/controller/codeinterpreting.go +++ b/components/execd/pkg/web/controller/codeinterpreting.go @@ -137,9 +137,26 @@ func (c *CodeInterpretingController) GetContext() { model.ErrorCodeMissingQuery, "missing path parameter 'contextId'", ) + return } - codeContext := codeRunner.GetContext(contextID) + codeContext, err := codeRunner.GetContext(contextID) + if err != nil { + if errors.Is(err, runtime.ErrContextNotFound) { + c.RespondError( + http.StatusNotFound, + model.ErrorCodeContextNotFound, + fmt.Sprintf("context %s not found", contextID), + ) + return + } + c.RespondError( + http.StatusInternalServerError, + model.ErrorCodeRuntimeError, + fmt.Sprintf("error getting code context %s. %v", contextID, err), + ) + return + } c.RespondSuccess(codeContext) } diff --git a/components/execd/pkg/web/controller/codeinterpreting_test.go b/components/execd/pkg/web/controller/codeinterpreting_test.go index 53b8e6bc..6b2b7ead 100644 --- a/components/execd/pkg/web/controller/codeinterpreting_test.go +++ b/components/execd/pkg/web/controller/codeinterpreting_test.go @@ -15,8 +15,12 @@ package controller import ( + "encoding/json" + "net/http" "testing" + "github.com/gin-gonic/gin" + "github.com/alibaba/opensandbox/execd/pkg/runtime" "github.com/alibaba/opensandbox/execd/pkg/web/model" ) @@ -59,3 +63,52 @@ func TestBuildExecuteCodeRequestRespectsLanguage(t *testing.T) { t.Fatalf("expected python language, got %s", execReq.Language) } } + +func TestGetContext_NotFoundReturns404(t *testing.T) { + ctx, w := newTestContext(http.MethodGet, "/code/contexts/missing", nil) + ctx.Params = append(ctx.Params, gin.Param{Key: "contextId", Value: "missing"}) + ctrl := NewCodeInterpretingController(ctx) + + previous := codeRunner + codeRunner = runtime.NewController("", "") + t.Cleanup(func() { codeRunner = previous }) + + ctrl.GetContext() + + if w.Code != http.StatusNotFound { + t.Fatalf("expected status %d, got %d", http.StatusNotFound, w.Code) + } + + var resp model.ErrorResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp.Code != model.ErrorCodeContextNotFound { + t.Fatalf("unexpected error code: %s", resp.Code) + } + if resp.Message != "context missing not found" { + t.Fatalf("unexpected message: %s", resp.Message) + } +} + +func TestGetContext_MissingIDReturns400(t *testing.T) { + ctx, w := newTestContext(http.MethodGet, "/code/contexts/", nil) + ctrl := NewCodeInterpretingController(ctx) + + ctrl.GetContext() + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) + } + + var resp model.ErrorResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp.Code != model.ErrorCodeMissingQuery { + t.Fatalf("unexpected error code: %s", resp.Code) + } + if resp.Message != "missing path parameter 'contextId'" { + t.Fatalf("unexpected message: %s", resp.Message) + } +}