/ components / execd / pkg / runtime / context_test.go
context_test.go
  1  // Copyright 2025 Alibaba Group Holding Ltd.
  2  //
  3  // Licensed under the Apache License, Version 2.0 (the "License");
  4  // you may not use this file except in compliance with the License.
  5  // You may obtain a copy of the License at
  6  //
  7  //     http://www.apache.org/licenses/LICENSE-2.0
  8  //
  9  // Unless required by applicable law or agreed to in writing, software
 10  // distributed under the License is distributed on an "AS IS" BASIS,
 11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12  // See the License for the specific language governing permissions and
 13  // limitations under the License.
 14  
 15  package runtime
 16  
 17  import (
 18  	"net/http"
 19  	"net/http/httptest"
 20  	"os"
 21  	"path/filepath"
 22  	"strings"
 23  	"testing"
 24  
 25  	"github.com/stretchr/testify/require"
 26  )
 27  
 28  func TestListContextsAndNewIpynbPath(t *testing.T) {
 29  	c := NewController("http://example", "token")
 30  	c.jupyterClientMap.Store("session-python", &jupyterKernel{language: Python})
 31  	c.defaultLanguageSessions.Store(Go, "session-go-default")
 32  
 33  	pyContexts, err := c.listLanguageContexts(Python)
 34  	require.NoError(t, err)
 35  	require.Len(t, pyContexts, 1)
 36  	require.Equal(t, "session-python", pyContexts[0].ID)
 37  	require.Equal(t, Python, pyContexts[0].Language)
 38  
 39  	allContexts, err := c.listAllContexts()
 40  	require.NoError(t, err)
 41  	require.Len(t, allContexts, 2)
 42  
 43  	tmpDir := filepath.Join(t.TempDir(), "nested")
 44  	path, err := c.newIpynbPath("abc123", tmpDir)
 45  	require.NoError(t, err)
 46  	_, statErr := os.Stat(tmpDir)
 47  	require.NoError(t, statErr, "expected directory to be created")
 48  	expected := filepath.Join(tmpDir, "abc123.ipynb")
 49  	require.Equal(t, expected, path)
 50  }
 51  
 52  func TestNewContextID_UniqueAndLength(t *testing.T) {
 53  	c := NewController("", "")
 54  	id1 := c.newContextID()
 55  	id2 := c.newContextID()
 56  
 57  	require.NotEmpty(t, id1)
 58  	require.NotEmpty(t, id2)
 59  	require.NotEqual(t, id1, id2, "expected unique ids")
 60  	require.Len(t, id1, 32)
 61  	require.Len(t, id2, 32)
 62  }
 63  
 64  func TestNewIpynbPath_ErrorWhenCwdIsFile(t *testing.T) {
 65  	c := NewController("", "")
 66  	tmpFile := filepath.Join(t.TempDir(), "file.txt")
 67  	require.NoError(t, os.WriteFile(tmpFile, []byte("x"), 0o644))
 68  
 69  	_, err := c.newIpynbPath("abc", tmpFile)
 70  	require.Error(t, err, "expected error when cwd is a file")
 71  }
 72  
 73  func TestNewIpynbPath_ExpandsHome(t *testing.T) {
 74  	home := t.TempDir()
 75  	t.Setenv("HOME", home)
 76  	t.Setenv("USERPROFILE", home)
 77  
 78  	c := NewController("", "")
 79  	path, err := c.newIpynbPath("abc", "~/workspace")
 80  	require.NoError(t, err)
 81  	require.Equal(t, filepath.Join(home, "workspace", "abc.ipynb"), path)
 82  }
 83  
 84  func TestListContextUnsupportedLanguage(t *testing.T) {
 85  	c := NewController("", "")
 86  	_, err := c.ListContext(Command.String())
 87  	require.Error(t, err, "expected error for command language")
 88  	_, err = c.ListContext(BackgroundCommand.String())
 89  	require.Error(t, err, "expected error for background-command language")
 90  	_, err = c.ListContext(SQL.String())
 91  	require.Error(t, err, "expected error for sql language")
 92  }
 93  
 94  func TestDeleteContext_NotFound(t *testing.T) {
 95  	c := NewController("", "")
 96  	err := c.DeleteContext("missing")
 97  	require.Error(t, err, "expected ErrContextNotFound")
 98  	require.ErrorIs(t, err, ErrContextNotFound)
 99  }
100  
101  func TestGetContext_NotFound(t *testing.T) {
102  	c := NewController("", "")
103  
104  	_, err := c.GetContext("missing")
105  	require.Error(t, err, "expected ErrContextNotFound")
106  	require.ErrorIs(t, err, ErrContextNotFound)
107  }
108  
109  func TestDeleteContext_RemovesCacheOnSuccess(t *testing.T) {
110  	sessionID := "sess-123"
111  
112  	// mock jupyter server that accepts DELETE
113  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
114  		require.Equal(t, http.MethodDelete, r.Method, "unexpected method")
115  		require.True(t, strings.HasSuffix(r.URL.Path, "/api/sessions/"+sessionID), "unexpected path: %s", r.URL.Path)
116  		w.WriteHeader(http.StatusNoContent)
117  	}))
118  	defer server.Close()
119  
120  	c := NewController(server.URL, "token")
121  	c.jupyterClientMap.Store(sessionID, &jupyterKernel{language: Python})
122  	c.defaultLanguageSessions.Store(Python, sessionID)
123  
124  	require.NoError(t, c.DeleteContext(sessionID))
125  
126  	require.Nil(t, c.getJupyterKernel(sessionID), "expected cache to be cleared")
127  	_, ok := c.defaultLanguageSessions.Load(Python)
128  	require.False(t, ok, "expected default session entry to be removed")
129  }
130  
131  func TestDeleteLanguageContext_RemovesCacheOnSuccess(t *testing.T) {
132  	lang := Python
133  	session1 := "sess-1"
134  	session2 := "sess-2"
135  
136  	// mock jupyter server to accept two deletes
137  	deleteCalls := make(map[string]int)
138  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
139  		require.Equal(t, http.MethodDelete, r.Method, "unexpected method")
140  		if strings.Contains(r.URL.Path, session1) {
141  			deleteCalls[session1]++
142  		} else if strings.Contains(r.URL.Path, session2) {
143  			deleteCalls[session2]++
144  		} else {
145  			require.Failf(t, "unexpected path", "%s", r.URL.Path)
146  		}
147  		w.WriteHeader(http.StatusNoContent)
148  	}))
149  	defer server.Close()
150  
151  	c := NewController(server.URL, "token")
152  	c.jupyterClientMap.Store(session1, &jupyterKernel{language: lang})
153  	c.jupyterClientMap.Store(session2, &jupyterKernel{language: lang})
154  	c.defaultLanguageSessions.Store(lang, session2)
155  
156  	require.NoError(t, c.DeleteLanguageContext(lang))
157  
158  	_, ok := c.jupyterClientMap.Load(session1)
159  	require.False(t, ok, "expected session1 removed from cache")
160  	_, ok = c.jupyterClientMap.Load(session2)
161  	require.False(t, ok, "expected session2 removed from cache")
162  	_, ok = c.defaultLanguageSessions.Load(lang)
163  	require.False(t, ok, "expected default entry removed")
164  	require.Equal(t, 1, deleteCalls[session1])
165  	require.Equal(t, 1, deleteCalls[session2])
166  }