package inference

import (
	"strings"
	"testing"
)

// TestProviderString checks that each Provider renders to its canonical
// lower-case identifier via String().
func TestProviderString(t *testing.T) {
	tests := []struct {
		provider Provider
		expected string
	}{
		{ProviderZAI, "zai"},
		{ProviderVenice, "venice"},
		{ProviderOpenRouter, "openrouter"},
	}

	for _, tt := range tests {
		t.Run(tt.expected, func(t *testing.T) {
			if got := tt.provider.String(); got != tt.expected {
				t.Errorf("Provider.String() = %q, want %q", got, tt.expected)
			}
		})
	}
}

// TestProviderInfo checks the display name, API-key environment variable and
// base URL returned by Provider.Info() for every known provider.
func TestProviderInfo(t *testing.T) {
	tests := []struct {
		provider     Provider
		expectedName string
		expectedEnv  string
		expectedURL  string
	}{
		{ProviderZAI, "Z.ai", "GLM_API_KEY", "https://api.z.ai/api/coding/paas/v4"},
		{ProviderVenice, "Venice.ai", "VENICE_API_KEY", "https://api.venice.ai/api/v1"},
		{ProviderOpenRouter, "OpenRouter", "OPENROUTER_API_KEY", "https://openrouter.ai/api/v1"},
	}

	for _, tt := range tests {
		t.Run(tt.expectedName, func(t *testing.T) {
			name, env, url := tt.provider.Info()
			if name != tt.expectedName {
				t.Errorf("Provider.Info() name = %q, want %q", name, tt.expectedName)
			}
			if env != tt.expectedEnv {
				t.Errorf("Provider.Info() env = %q, want %q", env, tt.expectedEnv)
			}
			if url != tt.expectedURL {
				t.Errorf("Provider.Info() url = %q, want %q", url, tt.expectedURL)
			}
		})
	}
}

// TestFallbackChain checks that FallbackChain() returns the provider itself
// followed by its fallbacks, in order.
func TestFallbackChain(t *testing.T) {
	tests := []struct {
		provider      Provider
		expectedChain []Provider
	}{
		// ZAI -> Venice -> OpenRouter
		{ProviderZAI, []Provider{ProviderZAI, ProviderVenice, ProviderOpenRouter}},
		// Venice -> OpenRouter
		{ProviderVenice, []Provider{ProviderVenice, ProviderOpenRouter}},
		// OpenRouter has no fallback.
		{ProviderOpenRouter, []Provider{ProviderOpenRouter}},
	}

	for _, tt := range tests {
		t.Run(tt.provider.String(), func(t *testing.T) {
			chain := tt.provider.FallbackChain()
			// Stop on a length mismatch so the element loop below cannot
			// index past the end of either slice.
			if len(chain) != len(tt.expectedChain) {
				t.Fatalf("FallbackChain() = %v (length %d), want %v (length %d)",
					chain, len(chain), tt.expectedChain, len(tt.expectedChain))
			}
			for i := range chain {
				if chain[i] != tt.expectedChain[i] {
					t.Errorf("FallbackChain()[%d] = %v, want %v", i, chain[i], tt.expectedChain[i])
				}
			}
		})
	}
}

// TestProviderUnmarshal checks that UnmarshalText accepts the canonical names,
// dotted/hyphenated aliases and mixed-case spellings, and rejects unknown input.
func TestProviderUnmarshal(t *testing.T) {
	tests := []struct {
		input       string
		expected    Provider
		expectError bool
	}{
		{"zai", ProviderZAI, false},
		{"z.ai", ProviderZAI, false},
		{"ZAI", ProviderZAI, false},
		{"venice", ProviderVenice, false},
		{"venice.ai", ProviderVenice, false},
		{"Venice", ProviderVenice, false},
		{"openrouter", ProviderOpenRouter, false},
		{"open-router", ProviderOpenRouter, false},
		{"OpenRouter", ProviderOpenRouter, false},
		{"unknown", Provider(""), true},
	}

	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			var p Provider
			err := p.UnmarshalText([]byte(tt.input))

			if tt.expectError {
				if err == nil {
					t.Error("UnmarshalText() expected error, got nil")
				}
				return
			}

			if err != nil {
				t.Errorf("UnmarshalText() unexpected error: %v", err)
			}
			if p != tt.expected {
				t.Errorf("UnmarshalText() = %v, want %v", p, tt.expected)
			}
		})
	}
}

// TestNewProviderSelection checks the primary provider, model, token limit and
// fallback-chain length produced by NewProviderSelection for Z.ai.
func TestNewProviderSelection(t *testing.T) {
	selection := NewProviderSelection(ProviderZAI)

	if selection.Primary != ProviderZAI {
		t.Errorf("NewProviderSelection() Primary = %v, want %v", selection.Primary, ProviderZAI)
	}
	if selection.Model != "glm-5.1" {
		t.Errorf("NewProviderSelection() Model = %q, want %q", selection.Model, "glm-5.1")
	}
	if selection.MaxTokens != 16384 {
		t.Errorf("NewProviderSelection() MaxTokens = %d, want %d", selection.MaxTokens, 16384)
	}

	// Verify fallback chain
	if len(selection.FallbackChain) != 3 {
		t.Errorf("NewProviderSelection() FallbackChain length = %d, want %d", len(selection.FallbackChain), 3)
	}
}

// TestProviderSelectionValidate checks that Validate accepts a well-formed
// selection and rejects non-positive or excessive MaxTokens and an empty model.
func TestProviderSelectionValidate(t *testing.T) {
	tests := []struct {
		name        string
		selection   *ProviderSelection
		expectError bool
	}{
		{
			name: "valid selection",
			selection: &ProviderSelection{
				Primary:       ProviderZAI,
				FallbackChain: []Provider{ProviderZAI, ProviderVenice, ProviderOpenRouter},
				Model:         "glm-5.1",
				MaxTokens:     16384,
			},
			expectError: false,
		},
		{
			name: "zero max_tokens",
			selection: &ProviderSelection{
				Primary:   ProviderZAI,
				Model:     "glm-5.1",
				MaxTokens: 0,
			},
			expectError: true,
		},
		{
			name: "negative max_tokens",
			selection: &ProviderSelection{
				Primary:   ProviderZAI,
				Model:     "glm-5.1",
				MaxTokens: -100,
			},
			expectError: true,
		},
		{
			name: "excessive max_tokens",
			selection: &ProviderSelection{
				Primary:   ProviderZAI,
				Model:     "glm-5.1",
				MaxTokens: 200000,
			},
			expectError: true,
		},
		{
			name: "empty model",
			selection: &ProviderSelection{
				Primary:   ProviderZAI,
				Model:     "",
				MaxTokens: 16384,
			},
			expectError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := tt.selection.Validate()

			if tt.expectError {
				if err == nil {
					t.Error("Validate() expected error, got nil")
				}
				return
			}

			if err != nil {
				t.Errorf("Validate() unexpected error: %v", err)
			}
		})
	}
}

// TestAPIKeyEnvVars checks that APIKeyEnvVars yields one entry per provider
// passed in.
func TestAPIKeyEnvVars(t *testing.T) {
	tests := []struct {
		name          string
		providers     []Provider
		expectedCount int
	}{
		{
			name:          "single provider",
			providers:     []Provider{ProviderZAI},
			expectedCount: 1,
		},
		{
			name:          "ZAI fallback chain",
			providers:     []Provider{ProviderZAI, ProviderVenice, ProviderOpenRouter},
			expectedCount: 3,
		},
		{
			name:          "Venice only",
			providers:     []Provider{ProviderVenice},
			expectedCount: 1,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			vars := APIKeyEnvVars(tt.providers...)
			if len(vars) != tt.expectedCount {
				t.Errorf("APIKeyEnvVars() length = %d, want %d", len(vars), tt.expectedCount)
			}
		})
	}
}

// TestDefaultGLMConfig checks the provider, model and token limit of the
// default GLM configuration.
func TestDefaultGLMConfig(t *testing.T) {
	config := DefaultGLMConfig()

	if config.Provider != ProviderZAI {
		t.Errorf("DefaultGLMConfig() Provider = %v, want %v", config.Provider, ProviderZAI)
	}
	if config.Model != "glm-5.1" {
		t.Errorf("DefaultGLMConfig() Model = %q, want %q", config.Model, "glm-5.1")
	}
	if config.MaxTokens != 16384 {
		t.Errorf("DefaultGLMConfig() MaxTokens = %d, want %d", config.MaxTokens, 16384)
	}
}

// TestGetProviderOptions checks that three provider options are offered and
// that the Z.ai option is present and marked as recommended.
func TestGetProviderOptions(t *testing.T) {
	options := GetProviderOptions()

	if len(options) != 3 {
		t.Errorf("GetProviderOptions() length = %d, want %d", len(options), 3)
	}

	// Check Z.ai is marked as recommended
	foundZAI := false
	for _, opt := range options {
		if opt.Provider != ProviderZAI {
			continue
		}
		foundZAI = true
		if !opt.Recommended {
			t.Error("Z.ai should be marked as recommended")
		}
	}

	if !foundZAI {
		t.Error("Z.ai provider not found in options")
	}
}

// TestFormatProviderList checks that the human-readable provider list names
// every provider and flags the recommended one.
func TestFormatProviderList(t *testing.T) {
	list := FormatProviderList()

	// Check that it contains expected provider names
	expectedStrings := []string{"Z.ai", "Venice.ai", "OpenRouter", "recommended"}
	for _, s := range expectedStrings {
		if !contains(list, s) {
			t.Errorf("FormatProviderList() missing expected string %q in:\n%s", s, list)
		}
	}
}

// contains reports whether substr occurs within s. It is a thin wrapper over
// strings.Contains, kept (rather than inlined) because other test files in
// this package may call it.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}