ScuttleBot
| 5ac549c… | lmata | 1 | package llm |
| 5ac549c… | lmata | 2 | |
| 5ac549c… | lmata | 3 | import ( |
| 5ac549c… | lmata | 4 | "context" |
| 5ac549c… | lmata | 5 | "encoding/json" |
| 5ac549c… | lmata | 6 | "net/http" |
| 5ac549c… | lmata | 7 | "net/http/httptest" |
| 5ac549c… | lmata | 8 | "testing" |
| 5ac549c… | lmata | 9 | ) |
| 5ac549c… | lmata | 10 | |
| 5ac549c… | lmata | 11 | func TestOllamaSummarize(t *testing.T) { |
| 5ac549c… | lmata | 12 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 5ac549c… | lmata | 13 | if r.Method != "POST" { |
| 5ac549c… | lmata | 14 | t.Errorf("expected POST request, got %s", r.Method) |
| 5ac549c… | lmata | 15 | } |
| 5ac549c… | lmata | 16 | if r.URL.Path != "/api/generate" { |
| 5ac549c… | lmata | 17 | t.Errorf("unexpected path: %s", r.URL.Path) |
| 5ac549c… | lmata | 18 | } |
| 5ac549c… | lmata | 19 | |
| 5ac549c… | lmata | 20 | var req struct { |
| 5ac549c… | lmata | 21 | Model string `json:"model"` |
| 5ac549c… | lmata | 22 | Prompt string `json:"prompt"` |
| 5ac549c… | lmata | 23 | Stream bool `json:"stream"` |
| 5ac549c… | lmata | 24 | } |
| 5ac549c… | lmata | 25 | if err := json.NewDecoder(r.Body).Decode(&req); err != nil { |
| 5ac549c… | lmata | 26 | t.Fatalf("decode request: %v", err) |
| 5ac549c… | lmata | 27 | } |
| 5ac549c… | lmata | 28 | |
| 5ac549c… | lmata | 29 | resp := map[string]any{ |
| 5ac549c… | lmata | 30 | "response": "ollama response", |
| 5ac549c… | lmata | 31 | } |
| 5ac549c… | lmata | 32 | _ = json.NewEncoder(w).Encode(resp) |
| 5ac549c… | lmata | 33 | })) |
| 5ac549c… | lmata | 34 | defer srv.Close() |
| 5ac549c… | lmata | 35 | |
| 5ac549c… | lmata | 36 | p := newOllamaProvider(BackendConfig{ |
| 5ac549c… | lmata | 37 | Backend: "ollama", |
| 5ac549c… | lmata | 38 | Model: "test-model", |
| 5ac549c… | lmata | 39 | }, srv.URL, srv.Client()) |
| 5ac549c… | lmata | 40 | |
| 5ac549c… | lmata | 41 | got, err := p.Summarize(context.Background(), "test prompt") |
| 5ac549c… | lmata | 42 | if err != nil { |
| 5ac549c… | lmata | 43 | t.Fatalf("Summarize failed: %v", err) |
| 5ac549c… | lmata | 44 | } |
| 5ac549c… | lmata | 45 | if got != "ollama response" { |
| 5ac549c… | lmata | 46 | t.Errorf("got %q, want %q", got, "ollama response") |
| 5ac549c… | lmata | 47 | } |
| 5ac549c… | lmata | 48 | } |
| 5ac549c… | lmata | 49 | |
| 5ac549c… | lmata | 50 | func TestOllamaDiscoverModels(t *testing.T) { |
| 5ac549c… | lmata | 51 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 5ac549c… | lmata | 52 | if r.Method != "GET" { |
| 5ac549c… | lmata | 53 | t.Errorf("expected GET request, got %s", r.Method) |
| 5ac549c… | lmata | 54 | } |
| 5ac549c… | lmata | 55 | if r.URL.Path != "/api/tags" { |
| 5ac549c… | lmata | 56 | t.Errorf("unexpected path: %s", r.URL.Path) |
| 5ac549c… | lmata | 57 | } |
| 5ac549c… | lmata | 58 | |
| 5ac549c… | lmata | 59 | resp := map[string]any{ |
| 5ac549c… | lmata | 60 | "models": []map[string]any{ |
| 5ac549c… | lmata | 61 | {"name": "model1"}, |
| 5ac549c… | lmata | 62 | {"name": "model2"}, |
| 5ac549c… | lmata | 63 | }, |
| 5ac549c… | lmata | 64 | } |
| 5ac549c… | lmata | 65 | _ = json.NewEncoder(w).Encode(resp) |
| 5ac549c… | lmata | 66 | })) |
| 5ac549c… | lmata | 67 | defer srv.Close() |
| 5ac549c… | lmata | 68 | |
| 5ac549c… | lmata | 69 | p := newOllamaProvider(BackendConfig{ |
| 5ac549c… | lmata | 70 | Backend: "ollama", |
| 5ac549c… | lmata | 71 | }, srv.URL, srv.Client()) |
| 5ac549c… | lmata | 72 | |
| 5ac549c… | lmata | 73 | models, err := p.DiscoverModels(context.Background()) |
| 5ac549c… | lmata | 74 | if err != nil { |
| 5ac549c… | lmata | 75 | t.Fatalf("DiscoverModels failed: %v", err) |
| 5ac549c… | lmata | 76 | } |
| 5ac549c… | lmata | 77 | |
| 5ac549c… | lmata | 78 | if len(models) != 2 { |
| 5ac549c… | lmata | 79 | t.Errorf("got %d models, want 2", len(models)) |
| 5ac549c… | lmata | 80 | } |
| 5ac549c… | lmata | 81 | } |