obm/internal/inference/inference_test.go
MermaidMan 33d9a2cb2e deploy walkthrough, API validation, inference client, Hetzner provider
- Interactive deploy command with 8-step walkthrough:
  framework → provider → token → SSH → server → inference → tailscale → discord
- .env file generation from walkthrough config
- DeploymentConfig struct with framework-aware defaults
- Inference API client with validation for Venice, OpenRouter, OpenAI, Anthropic
- Hetzner Cloud provider: token validation, SSH key listing
- DotEnv parser/writer with schema validation
- Destroy command with confirmation prompt
- Validation subcommand for checking existing .env files
- All tests passing, go vet clean
2026-05-22 15:29:27 +00:00

292 lines
7.1 KiB
Go

package inference
import (
	"strings"
	"testing"
)
// TestProviderString checks that each Provider constant renders to its
// lowercase identifier ("zai", "venice", "openrouter") via String().
func TestProviderString(t *testing.T) {
	for _, tc := range []struct {
		p    Provider
		want string
	}{
		{p: ProviderZAI, want: "zai"},
		{p: ProviderVenice, want: "venice"},
		{p: ProviderOpenRouter, want: "openrouter"},
	} {
		t.Run(tc.want, func(t *testing.T) {
			got := tc.p.String()
			if got == tc.want {
				return
			}
			t.Errorf("Provider.String() = %q, want %q", got, tc.want)
		})
	}
}
// TestProviderInfo verifies the display name, API-key environment variable
// and base URL returned by Info() for each supported provider.
func TestProviderInfo(t *testing.T) {
	cases := []struct {
		p       Provider
		name    string
		envVar  string
		baseURL string
	}{
		{ProviderZAI, "Z.ai", "GLM_API_KEY", "https://api.z.ai/api/coding/paas/v4"},
		{ProviderVenice, "Venice.ai", "VENICE_API_KEY", "https://api.venice.ai/api/v1"},
		{ProviderOpenRouter, "OpenRouter", "OPENROUTER_API_KEY", "https://openrouter.ai/api/v1"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotName, gotEnv, gotURL := tc.p.Info()
			if gotName != tc.name {
				t.Errorf("Provider.Info() name = %q, want %q", gotName, tc.name)
			}
			if gotEnv != tc.envVar {
				t.Errorf("Provider.Info() env = %q, want %q", gotEnv, tc.envVar)
			}
			if gotURL != tc.baseURL {
				t.Errorf("Provider.Info() url = %q, want %q", gotURL, tc.baseURL)
			}
		})
	}
}
// TestFallbackChain verifies that FallbackChain returns the provider itself
// followed by its fallbacks in order:
//
//	Z.ai       -> Venice -> OpenRouter
//	Venice     -> OpenRouter
//	OpenRouter (no fallback)
//
// Every element is compared, not just the length and head, so a reordered or
// duplicated tail is caught.
func TestFallbackChain(t *testing.T) {
	tests := []struct {
		provider Provider
		expected []Provider
	}{
		{ProviderZAI, []Provider{ProviderZAI, ProviderVenice, ProviderOpenRouter}},
		{ProviderVenice, []Provider{ProviderVenice, ProviderOpenRouter}},
		{ProviderOpenRouter, []Provider{ProviderOpenRouter}},
	}
	for _, tt := range tests {
		t.Run(tt.provider.String(), func(t *testing.T) {
			chain := tt.provider.FallbackChain()
			if len(chain) != len(tt.expected) {
				// Stop here: element-wise comparison below would index out of range.
				t.Fatalf("FallbackChain() length = %d, want %d (got %v)", len(chain), len(tt.expected), chain)
			}
			for i := range tt.expected {
				if chain[i] != tt.expected[i] {
					t.Errorf("FallbackChain()[%d] = %v, want %v", i, chain[i], tt.expected[i])
				}
			}
		})
	}
}
// TestProviderUnmarshal checks that UnmarshalText accepts the canonical and
// alias spellings of each provider (case-insensitively for the inputs listed)
// and rejects an unknown name with an error.
func TestProviderUnmarshal(t *testing.T) {
	cases := []struct {
		in      string
		want    Provider
		wantErr bool
	}{
		{"zai", ProviderZAI, false},
		{"z.ai", ProviderZAI, false},
		{"ZAI", ProviderZAI, false},
		{"venice", ProviderVenice, false},
		{"venice.ai", ProviderVenice, false},
		{"Venice", ProviderVenice, false},
		{"openrouter", ProviderOpenRouter, false},
		{"open-router", ProviderOpenRouter, false},
		{"OpenRouter", ProviderOpenRouter, false},
		{"unknown", Provider(""), true},
	}
	for _, tc := range cases {
		t.Run(tc.in, func(t *testing.T) {
			var got Provider
			err := got.UnmarshalText([]byte(tc.in))
			switch {
			case tc.wantErr && err == nil:
				t.Error("UnmarshalText() expected error, got nil")
			case tc.wantErr:
				// Error expected and returned; the decoded value is not checked.
			default:
				if err != nil {
					t.Errorf("UnmarshalText() unexpected error: %v", err)
				}
				if got != tc.want {
					t.Errorf("UnmarshalText() = %v, want %v", got, tc.want)
				}
			}
		})
	}
}
// TestNewProviderSelection checks the defaults NewProviderSelection fills in
// for a Z.ai primary: model "glm-5.1", 16384 max tokens and a fallback chain
// of three providers.
func TestNewProviderSelection(t *testing.T) {
	const (
		wantModel     = "glm-5.1"
		wantMaxTokens = 16384
		wantChainLen  = 3
	)
	sel := NewProviderSelection(ProviderZAI)
	if got := sel.Primary; got != ProviderZAI {
		t.Errorf("NewProviderSelection() Primary = %v, want %v", got, ProviderZAI)
	}
	if got := sel.Model; got != wantModel {
		t.Errorf("NewProviderSelection() Model = %q, want %q", got, wantModel)
	}
	if got := sel.MaxTokens; got != wantMaxTokens {
		t.Errorf("NewProviderSelection() MaxTokens = %d, want %d", got, wantMaxTokens)
	}
	if got := len(sel.FallbackChain); got != wantChainLen {
		t.Errorf("NewProviderSelection() FallbackChain length = %d, want %d", got, wantChainLen)
	}
}
// TestProviderSelectionValidate exercises ProviderSelection.Validate.
//
// Every case starts from the same fully valid selection and changes exactly
// one field. An expected error can therefore only be caused by the field under
// test, not by an incidentally missing FallbackChain (in case Validate also
// checks the chain).
func TestProviderSelectionValidate(t *testing.T) {
	// valid returns a fresh, fully valid selection so cases never share state.
	valid := func() *ProviderSelection {
		return &ProviderSelection{
			Primary:       ProviderZAI,
			FallbackChain: []Provider{ProviderZAI, ProviderVenice, ProviderOpenRouter},
			Model:         "glm-5.1",
			MaxTokens:     16384,
		}
	}
	tests := []struct {
		name        string
		mutate      func(s *ProviderSelection)
		expectError bool
	}{
		{name: "valid selection", mutate: func(*ProviderSelection) {}, expectError: false},
		{name: "zero max_tokens", mutate: func(s *ProviderSelection) { s.MaxTokens = 0 }, expectError: true},
		{name: "negative max_tokens", mutate: func(s *ProviderSelection) { s.MaxTokens = -100 }, expectError: true},
		{name: "excessive max_tokens", mutate: func(s *ProviderSelection) { s.MaxTokens = 200000 }, expectError: true},
		{name: "empty model", mutate: func(s *ProviderSelection) { s.Model = "" }, expectError: true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sel := valid()
			tt.mutate(sel)
			err := sel.Validate()
			if tt.expectError && err == nil {
				t.Error("Validate() expected error, got nil")
			}
			if !tt.expectError && err != nil {
				t.Errorf("Validate() unexpected error: %v", err)
			}
		})
	}
}
// TestAPIKeyEnvVars checks that, for each input below, APIKeyEnvVars returns
// one entry per provider passed in.
func TestAPIKeyEnvVars(t *testing.T) {
	for _, tc := range []struct {
		name      string
		providers []Provider
		want      int
	}{
		{"single provider", []Provider{ProviderZAI}, 1},
		{"ZAI fallback chain", []Provider{ProviderZAI, ProviderVenice, ProviderOpenRouter}, 3},
		{"Venice only", []Provider{ProviderVenice}, 1},
	} {
		t.Run(tc.name, func(t *testing.T) {
			if n := len(APIKeyEnvVars(tc.providers...)); n != tc.want {
				t.Errorf("APIKeyEnvVars() length = %d, want %d", n, tc.want)
			}
		})
	}
}
// TestDefaultGLMConfig checks that DefaultGLMConfig targets Z.ai with the
// "glm-5.1" model and a 16384 max-token limit.
func TestDefaultGLMConfig(t *testing.T) {
	cfg := DefaultGLMConfig()
	if p := cfg.Provider; p != ProviderZAI {
		t.Errorf("DefaultGLMConfig() Provider = %v, want %v", p, ProviderZAI)
	}
	if m := cfg.Model; m != "glm-5.1" {
		t.Errorf("DefaultGLMConfig() Model = %q, want %q", m, "glm-5.1")
	}
	if n := cfg.MaxTokens; n != 16384 {
		t.Errorf("DefaultGLMConfig() MaxTokens = %d, want %d", n, 16384)
	}
}
// TestGetProviderOptions checks that GetProviderOptions lists three options
// and that the Z.ai entry is present and flagged Recommended.
func TestGetProviderOptions(t *testing.T) {
	opts := GetProviderOptions()
	if n := len(opts); n != 3 {
		t.Errorf("GetProviderOptions() length = %d, want %d", n, 3)
	}
	zaiSeen := false
	for i := range opts {
		if opts[i].Provider != ProviderZAI {
			continue
		}
		zaiSeen = true
		if !opts[i].Recommended {
			t.Error("Z.ai should be marked as recommended")
		}
	}
	if !zaiSeen {
		t.Error("Z.ai provider not found in options")
	}
}
// TestFormatProviderList checks that the formatted provider list mentions
// every provider name and contains the word "recommended".
func TestFormatProviderList(t *testing.T) {
	out := FormatProviderList()
	for _, want := range [...]string{"Z.ai", "Venice.ai", "OpenRouter", "recommended"} {
		if contains(out, want) {
			continue
		}
		t.Errorf("FormatProviderList() missing expected string %q", want)
	}
}
// contains reports whether substr occurs anywhere within s.
//
// It delegates to strings.Contains instead of hand-rolling the scan. An empty
// substr is always contained and a substr longer than s never is, exactly as
// before.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}