| | @@ -0,0 +1,134 @@ |
| 1 | +package llm
|
| 2 | +
|
| 3 | +import (
|
| 4 | + "testing"
|
| 5 | +)
|
| 6 | +
|
| 7 | +func models(ids ...string) []ModelInfo {
|
| 8 | + out := make([]ModelInfo, len(ids))
|
| 9 | + for i, id := range ids {
|
| 10 | + out[i] = ModelInfo{ID: id, Name: id}
|
| 11 | + }
|
| 12 | + return out
|
| 13 | +}
|
| 14 | +
|
| 15 | +func ids(ms []ModelInfo) []string {
|
| 16 | + out := make([]string, len(ms))
|
| 17 | + for i, m := range ms {
|
| 18 | + out[i] = m.ID
|
| 19 | + }
|
| 20 | + return out
|
| 21 | +}
|
| 22 | +
|
| 23 | +func TestNewModelFilterInvalidAllow(t *testing.T) {
|
| 24 | + _, err := NewModelFilter([]string{"["}, nil)
|
| 25 | + if err == nil {
|
| 26 | + t.Fatal("expected error for invalid allow pattern, got nil")
|
| 27 | + }
|
| 28 | +}
|
| 29 | +
|
| 30 | +func TestNewModelFilterInvalidBlock(t *testing.T) {
|
| 31 | + _, err := NewModelFilter(nil, []string{"["})
|
| 32 | + if err == nil {
|
| 33 | + t.Fatal("expected error for invalid block pattern, got nil")
|
| 34 | + }
|
| 35 | +}
|
| 36 | +
|
| 37 | +func TestFilterNoPatterns(t *testing.T) {
|
| 38 | + f, err := NewModelFilter(nil, nil)
|
| 39 | + if err != nil {
|
| 40 | + t.Fatal(err)
|
| 41 | + }
|
| 42 | + input := models("gpt-4", "claude-3", "gemini-pro")
|
| 43 | + got := f.Apply(input)
|
| 44 | + if len(got) != len(input) {
|
| 45 | + t.Errorf("no patterns: got %d models, want %d", len(got), len(input))
|
| 46 | + }
|
| 47 | +}
|
| 48 | +
|
| 49 | +func TestFilterAllowOnly(t *testing.T) {
|
| 50 | + f, err := NewModelFilter([]string{"^claude"}, nil)
|
| 51 | + if err != nil {
|
| 52 | + t.Fatal(err)
|
| 53 | + }
|
| 54 | + got := f.Apply(models("claude-3-sonnet", "gpt-4", "claude-haiku", "gemini-pro"))
|
| 55 | + gotIDs := ids(got)
|
| 56 | + want := []string{"claude-3-sonnet", "claude-haiku"}
|
| 57 | + if len(gotIDs) != len(want) {
|
| 58 | + t.Fatalf("allow-only: got %v, want %v", gotIDs, want)
|
| 59 | + }
|
| 60 | + for i, id := range gotIDs {
|
| 61 | + if id != want[i] {
|
| 62 | + t.Errorf("allow-only[%d]: got %q, want %q", i, id, want[i])
|
| 63 | + }
|
| 64 | + }
|
| 65 | +}
|
| 66 | +
|
| 67 | +func TestFilterBlockOnly(t *testing.T) {
|
| 68 | + f, err := NewModelFilter(nil, []string{"preview", "legacy"})
|
| 69 | + if err != nil {
|
| 70 | + t.Fatal(err)
|
| 71 | + }
|
| 72 | + got := f.Apply(models("gpt-4", "gpt-4-preview", "claude-3", "claude-legacy"))
|
| 73 | + gotIDs := ids(got)
|
| 74 | + want := []string{"gpt-4", "claude-3"}
|
| 75 | + if len(gotIDs) != len(want) {
|
| 76 | + t.Fatalf("block-only: got %v, want %v", gotIDs, want)
|
| 77 | + }
|
| 78 | + for i, id := range gotIDs {
|
| 79 | + if id != want[i] {
|
| 80 | + t.Errorf("block-only[%d]: got %q, want %q", i, id, want[i])
|
| 81 | + }
|
| 82 | + }
|
| 83 | +}
|
| 84 | +
|
| 85 | +func TestFilterAllowAndBlock(t *testing.T) {
|
| 86 | + // Allow claude-*, block anything with "legacy".
|
| 87 | + f, err := NewModelFilter([]string{"^claude"}, []string{"legacy"})
|
| 88 | + if err != nil {
|
| 89 | + t.Fatal(err)
|
| 90 | + }
|
| 91 | + got := f.Apply(models("claude-3", "claude-legacy", "gpt-4", "gemini"))
|
| 92 | + gotIDs := ids(got)
|
| 93 | + // Only claude-3 survives: claude-legacy is blocked, gpt-4/gemini not in allowlist.
|
| 94 | + if len(gotIDs) != 1 || gotIDs[0] != "claude-3" {
|
| 95 | + t.Errorf("allow+block: got %v, want [claude-3]", gotIDs)
|
| 96 | + }
|
| 97 | +}
|
| 98 | +
|
| 99 | +func TestFilterEmptyInput(t *testing.T) {
|
| 100 | + f, err := NewModelFilter([]string{"^claude"}, []string{"legacy"})
|
| 101 | + if err != nil {
|
| 102 | + t.Fatal(err)
|
| 103 | + }
|
| 104 | + got := f.Apply(nil)
|
| 105 | + if len(got) != 0 {
|
| 106 | + t.Errorf("empty input: got %d models, want 0", len(got))
|
| 107 | + }
|
| 108 | +}
|
| 109 | +
|
| 110 | +func TestFilterBlockTakesPrecedenceOverAllow(t *testing.T) {
|
| 111 | + // Pattern matches both allow and block — block wins.
|
| 112 | + f, err := NewModelFilter([]string{"claude"}, []string{"claude-3"})
|
| 113 | + if err != nil {
|
| 114 | + t.Fatal(err)
|
| 115 | + }
|
| 116 | + got := f.Apply(models("claude-3", "claude-haiku"))
|
| 117 | + gotIDs := ids(got)
|
| 118 | + // claude-3 is blocked; claude-haiku passes allowlist.
|
| 119 | + if len(gotIDs) != 1 || gotIDs[0] != "claude-haiku" {
|
| 120 | + t.Errorf("block-over-allow: got %v, want [claude-haiku]", gotIDs)
|
| 121 | + }
|
| 122 | +}
|
| 123 | +
|
| 124 | +func TestFilterMultipleAllowPatterns(t *testing.T) {
|
| 125 | + f, err := NewModelFilter([]string{"^claude", "^gemini"}, nil)
|
| 126 | + if err != nil {
|
| 127 | + t.Fatal(err)
|
| 128 | + }
|
| 129 | + got := f.Apply(models("claude-3", "gpt-4", "gemini-pro", "llama"))
|
| 130 | + gotIDs := ids(got)
|
| 131 | + if len(gotIDs) != 2 {
|
| 132 | + t.Fatalf("multi-allow: got %v, want [claude-3 gemini-pro]", gotIDs)
|
| 133 | + }
|
| 134 | +}
|