// Package inference provides API client functionality for inference providers.
package inference

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
	"time"
)

// Client provides HTTP client functionality for inference providers.
type Client struct {
	httpClient *http.Client
	timeout    time.Duration
}

// NewClient creates a new inference client with a default 30-second timeout.
func NewClient() *Client {
	return &Client{
		httpClient: &http.Client{
			Timeout: 30 * time.Second,
		},
		timeout: 30 * time.Second,
	}
}

// NewClientWithTimeout creates a new inference client with a custom timeout.
func NewClientWithTimeout(timeout time.Duration) *Client {
	return &Client{
		httpClient: &http.Client{
			Timeout: timeout,
		},
		timeout: timeout,
	}
}

// ValidationResult contains the result of an API key validation attempt.
type ValidationResult struct {
	Provider     Provider `json:"provider"`
	Valid        bool     `json:"valid"`
	ErrorMessage string   `json:"error_message,omitempty"`
	ModelCount   int      `json:"model_count,omitempty"`
	Latency      int64    `json:"latency_ms"` // round-trip time in milliseconds
}

// ValidateAPIKey validates an API key for a provider by making a test request
// to the provider's /models endpoint. If apiKey is empty, the key is read from
// the provider's environment variable.
//
// The Valid field of the returned ValidationResult reports whether the key was
// accepted. A non-nil error is returned only for an unknown provider or when
// the request cannot be constructed. Missing keys, network failures, and
// non-200 responses are reported through ErrorMessage with a nil error.
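//
// Example usage (a sketch; it assumes the caller imports "log" and that the
// ProviderOpenRouter key is set in that provider's environment variable,
// because an empty apiKey falls back to it):
//
//	client := NewClientWithTimeout(10 * time.Second)
//	res, err := client.ValidateAPIKey(context.Background(), ProviderOpenRouter, "")
//	if err != nil {
//		log.Fatal(err) // unknown provider or request construction failure
//	}
//	if !res.Valid {
//		fmt.Println("key rejected:", res.ErrorMessage)
//	}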
func (c *Client) ValidateAPIKey(ctx context.Context, provider Provider, apiKey string) (*ValidationResult, error) {
	start := time.Now()
	result := &ValidationResult{
		Provider: provider,
	}

	// Get provider info
	name, envVar, baseURL := provider.Info()
	if baseURL == "" {
		result.ErrorMessage = fmt.Sprintf("unknown provider: %s", provider)
		return result, fmt.Errorf("unknown provider: %s", provider)
	}

	// Use the provided API key or fall back to the environment variable
	if apiKey == "" {
		apiKey = os.Getenv(envVar)
	}
	if apiKey == "" {
		result.ErrorMessage = fmt.Sprintf("%s API key not set (set %s)", name, envVar)
		return result, nil
	}

	// Make a test request to list models (validates auth without consuming tokens)
	url := fmt.Sprintf("%s/models", baseURL)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		result.ErrorMessage = fmt.Sprintf("failed to create request: %v", err)
		return result, err
	}

	// Set auth headers based on provider
	c.setAuthHeaders(req, provider, apiKey)

	resp, err := c.httpClient.Do(req)
	if err != nil {
		result.ErrorMessage = fmt.Sprintf("request failed: %v", err)
		result.Latency = time.Since(start).Milliseconds()
		return result, nil
	}
	defer resp.Body.Close()

	result.Latency = time.Since(start).Milliseconds()

	if resp.StatusCode == http.StatusOK {
		// Parse the response to count models. A read or parse failure does
		// not invalidate the key; it only leaves ModelCount at zero.
		body, err := io.ReadAll(resp.Body)
		if err == nil {
			var modelsResp ModelsResponse
			if json.Unmarshal(body, &modelsResp) == nil {
				result.ModelCount = len(modelsResp.Data)
			}
		}
		result.Valid = true
		return result, nil
	}

	// Handle specific error codes
	switch resp.StatusCode {
	case http.StatusUnauthorized:
		result.ErrorMessage = "invalid API key"
	case http.StatusForbidden:
		result.ErrorMessage = "API key lacks required permissions"
	case http.StatusTooManyRequests:
		result.ErrorMessage = "rate limited - key is valid but throttled"
		result.Valid = true // Key is valid, just throttled
	default:
		body, _ := io.ReadAll(resp.Body)
		result.ErrorMessage = fmt.Sprintf("HTTP %d: %s",
			resp.StatusCode, strings.TrimSpace(string(body)))
	}

	return result, nil
}

// Model represents a model returned by the /models endpoint.
type Model struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created,omitempty"`
	OwnedBy string `json:"owned_by,omitempty"`
}

// ModelsResponse represents the response from the OpenAI-compatible /models endpoint.
type ModelsResponse struct {
	Data []Model `json:"data"`
}

// ModelInfo contains detailed information about an available model.
type ModelInfo struct {
	ID          string `json:"id"`
	Provider    string `json:"provider"`
	Description string `json:"description,omitempty"`
	Context     int    `json:"context_window,omitempty"`
}

// ListModels lists the models available from a provider. If apiKey is empty,
// the key is read from the provider's environment variable.
func (c *Client) ListModels(ctx context.Context, provider Provider, apiKey string) ([]Model, error) {
	name, envVar, baseURL := provider.Info()
	if baseURL == "" {
		return nil, fmt.Errorf("unknown provider: %s", provider)
	}

	if apiKey == "" {
		apiKey = os.Getenv(envVar)
	}
	if apiKey == "" {
		return nil, fmt.Errorf("%s API key not set (set %s)", name, envVar)
	}

	url := fmt.Sprintf("%s/models", baseURL)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	c.setAuthHeaders(req, provider, apiKey)

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	var modelsResp ModelsResponse
	if err := json.Unmarshal(body, &modelsResp); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	return modelsResp.Data, nil
}

// ValidateAll validates API keys for all providers in the fallback chain.
// It returns a map from each provider to its validation result. Errors from
// ValidateAPIKey are not returned separately, because every error path there
// also sets the result's ErrorMessage.
func (c *Client) ValidateAll(ctx context.Context, selection *ProviderSelection, apiKeys map[Provider]string) map[Provider]*ValidationResult {
	results := make(map[Provider]*ValidationResult)
	for _, provider := range selection.FallbackChain {
		apiKey := apiKeys[provider]
		result, _ := c.ValidateAPIKey(ctx, provider, apiKey)
		results[provider] = result
	}
	return results
}

// setAuthHeaders sets authentication headers for a request based on the provider.
// Providers not listed here receive no authentication header.
func (c *Client) setAuthHeaders(req *http.Request, provider Provider, apiKey string) {
	switch provider {
	case ProviderZAI, ProviderVenice:
		req.Header.Set("Authorization", "Bearer "+apiKey)
	case ProviderOpenRouter:
		req.Header.Set("Authorization", "Bearer "+apiKey)
		// OpenRouter uses these optional headers to attribute requests to the calling app
		req.Header.Set("HTTP-Referer", "https://github.com/openboatmobile/obm")
		req.Header.Set("X-Title", "OpenBoatMobile")
	}
}

// FormatValidationResults formats validation results for display, in the
// order returned by SortedProviders.
func FormatValidationResults(results map[Provider]*ValidationResult) string {
	var sb strings.Builder
	sb.WriteString("API Key Validation Results:\n\n")

	for _, provider := range SortedProviders() {
		result, ok := results[provider]
		if !ok {
			continue
		}

		name, _, _ := provider.Info()
		status := "INVALID"
		if result.Valid {
			status = "VALID"
		}

		fmt.Fprintf(&sb, " %s: %s", name, status)
		if result.Valid && result.ModelCount > 0 {
			fmt.Fprintf(&sb, " (%d models)", result.ModelCount)
		}
		if result.Latency > 0 {
			fmt.Fprintf(&sb, " [%dms]", result.Latency)
		}
		if result.ErrorMessage != "" {
			fmt.Fprintf(&sb, " - %s", result.ErrorMessage)
		}
		sb.WriteString("\n")
	}

	return sb.String()
}

// FormatModelList formats a model list for display.
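//
// Example usage (a sketch; it assumes the ProviderZAI key is set in that
// provider's environment variable, because an empty apiKey falls back to it):
//
//	models, err := NewClient().ListModels(context.Background(), ProviderZAI, "")
//	if err != nil {
//		return err
//	}
//	fmt.Print(FormatModelList(models, ProviderZAI))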
func FormatModelList(models []Model, provider Provider) string {
	var sb strings.Builder
	name, _, _ := provider.Info()
	fmt.Fprintf(&sb, "Available models for %s:\n\n", name)

	if len(models) == 0 {
		sb.WriteString(" No models found\n")
		return sb.String()
	}

	for _, model := range models {
		fmt.Fprintf(&sb, " - %s", model.ID)
		if model.OwnedBy != "" {
			fmt.Fprintf(&sb, " (%s)", model.OwnedBy)
		}
		sb.WriteString("\n")
	}
	fmt.Fprintf(&sb, "\nTotal: %d models\n", len(models))

	return sb.String()
}