obm/internal/inference/client.go
MermaidMan 33d9a2cb2e deploy walkthrough, API validation, inference client, Hetzner provider
- Interactive deploy command with 8-step walkthrough:
  framework → provider → token → SSH → server → inference → tailscale → discord
- .env file generation from walkthrough config
- DeploymentConfig struct with framework-aware defaults
- Inference API client with validation for Venice, OpenRouter, OpenAI, Anthropic
- Hetzner Cloud provider: token validation, SSH key listing
- DotEnv parser/writer with schema validation
- Destroy command with confirmation prompt
- Validation subcommand for checking existing .env files
- All tests passing, go vet clean
2026-05-22 15:29:27 +00:00

287 lines
7.7 KiB
Go

// Package inference provides API client functionality for inference providers.
package inference
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
)
// Client issues authenticated HTTP requests to inference-provider APIs.
// Obtain one from NewClient or NewClientWithTimeout.
type Client struct {
	httpClient *http.Client  // performs every outbound request
	timeout    time.Duration // the same value passed as httpClient.Timeout
}

// NewClient returns a Client whose underlying http.Client uses a 30-second
// timeout.
func NewClient() *Client {
	return NewClientWithTimeout(30 * time.Second)
}

// NewClientWithTimeout returns a Client whose underlying http.Client is
// configured with the given timeout (net/http treats a zero Timeout as "no
// timeout"). Every call allocates its own *http.Client.
func NewClientWithTimeout(timeout time.Duration) *Client {
	c := &Client{timeout: timeout}
	c.httpClient = &http.Client{Timeout: timeout}
	return c
}
// ValidationResult contains the result of an API validation attempt, as
// produced by Client.ValidateAPIKey. Fields carry JSON tags.
type ValidationResult struct {
	// Provider is the provider whose key was checked.
	Provider Provider `json:"provider"`
	// Valid is true when the /models request returned 200 OK, and also when
	// it returned 429 Too Many Requests (key accepted but throttled).
	Valid bool `json:"valid"`
	// ErrorMessage explains a failure, or carries the rate-limit notice for
	// a throttled-but-valid key; empty (and omitted from JSON) otherwise.
	ErrorMessage string `json:"error_message,omitempty"`
	// ModelCount is the number of models listed in a 200 OK /models
	// response; zero (and omitted from JSON) if none were listed or the body
	// could not be parsed.
	ModelCount int `json:"model_count,omitempty"`
	// Latency is the time in milliseconds from the start of validation until
	// the HTTP response or transport error arrived. It stays zero when
	// validation returned before sending a request.
	Latency int64 `json:"latency_ms"`
}
// ValidateAPIKey checks apiKey for provider by sending a GET request to the
// provider's /models endpoint, which exercises authentication without
// consuming generation tokens.
//
// If apiKey is empty, the value of the provider's environment variable (as
// reported by provider.Info) is used instead.
//
// The returned *ValidationResult is always non-nil. A non-nil error is
// returned only when the provider is unknown or the request cannot be built.
// Transport failures and non-200 responses are reported through
// result.ErrorMessage with a nil error, so callers can show them alongside
// other results. A 429 response is treated as a valid but throttled key.
func (c *Client) ValidateAPIKey(ctx context.Context, provider Provider, apiKey string) (*ValidationResult, error) {
	// Error-response bodies are echoed into ErrorMessage; cap them so a large
	// HTML error page cannot flood the output.
	const maxErrorBody = 512
	// Upper bound on how much of an unread body is discarded before closing,
	// so the transport can reuse the connection.
	const maxDrain = 64 << 10

	start := time.Now()
	result := &ValidationResult{Provider: provider}

	name, envVar, baseURL := provider.Info()
	if baseURL == "" {
		result.ErrorMessage = fmt.Sprintf("unknown provider: %s", provider)
		return result, fmt.Errorf("unknown provider: %s", provider)
	}

	// Use the provided API key or fall back to the environment variable.
	if apiKey == "" {
		apiKey = os.Getenv(envVar)
	}
	if apiKey == "" {
		result.ErrorMessage = fmt.Sprintf("%s API key not set (set %s)", name, envVar)
		return result, nil
	}

	// Trim a trailing slash so a base URL such as ".../v1/" does not produce
	// ".../v1//models".
	endpoint := strings.TrimSuffix(baseURL, "/") + "/models"
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		result.ErrorMessage = fmt.Sprintf("failed to create request: %v", err)
		return result, err
	}
	c.setAuthHeaders(req, provider, apiKey)

	resp, err := c.httpClient.Do(req)
	result.Latency = time.Since(start).Milliseconds()
	if err != nil {
		result.ErrorMessage = fmt.Sprintf("request failed: %v", err)
		return result, nil
	}
	defer func() {
		// Best-effort drain; errors here do not affect the validation result.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrain))
		resp.Body.Close()
	}()

	switch resp.StatusCode {
	case http.StatusOK:
		result.Valid = true
		// The model count is informational only: a body that cannot be read
		// or parsed does not make the key invalid.
		if body, err := io.ReadAll(resp.Body); err == nil {
			var modelsResp ModelsResponse
			if json.Unmarshal(body, &modelsResp) == nil {
				result.ModelCount = len(modelsResp.Data)
			}
		}
	case http.StatusUnauthorized:
		result.ErrorMessage = "invalid API key"
	case http.StatusForbidden:
		result.ErrorMessage = "API key lacks required permissions"
	case http.StatusTooManyRequests:
		result.ErrorMessage = "rate limited - key is valid but throttled"
		result.Valid = true // Key is valid, just throttled
	default:
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		// The byte cap may split a multi-byte rune; drop any invalid tail.
		msg := strings.ToValidUTF8(strings.TrimSpace(string(body)), "")
		if msg == "" {
			msg = http.StatusText(resp.StatusCode)
		}
		result.ErrorMessage = fmt.Sprintf("HTTP %d: %s", resp.StatusCode, msg)
	}
	return result, nil
}
// Wire and display types for model listings.
type (
	// Model is one entry of the "data" array returned by a provider's
	// /models endpoint.
	Model struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created,omitempty"`
		OwnedBy string `json:"owned_by,omitempty"`
	}

	// ModelsResponse is the envelope of an OpenAI-compatible /models
	// response; only its "data" array is decoded.
	ModelsResponse struct {
		Data []Model `json:"data"`
	}

	// ModelInfo describes an available model together with the provider that
	// serves it; Context is serialized under the "context_window" key.
	ModelInfo struct {
		ID          string `json:"id"`
		Provider    string `json:"provider"`
		Description string `json:"description,omitempty"`
		Context     int    `json:"context_window,omitempty"`
	}
)
// ListModels returns the models advertised by provider's /models endpoint.
//
// If apiKey is empty, the value of the provider's environment variable (as
// reported by provider.Info) is used instead. An unknown provider, a missing
// key, a transport failure, a non-200 status, or an undecodable body is
// reported as an error; non-200 errors include the status code and a bounded
// prefix of the response body.
func (c *Client) ListModels(ctx context.Context, provider Provider, apiKey string) ([]Model, error) {
	// Error-response bodies are echoed into the error; cap them so a large
	// HTML error page cannot flood the output.
	const maxErrorBody = 512

	name, envVar, baseURL := provider.Info()
	if baseURL == "" {
		return nil, fmt.Errorf("unknown provider: %s", provider)
	}
	if apiKey == "" {
		apiKey = os.Getenv(envVar)
	}
	if apiKey == "" {
		// Use the display name, matching ValidateAPIKey's message.
		return nil, fmt.Errorf("%s API key not set (set %s)", name, envVar)
	}

	// Trim a trailing slash so a base URL such as ".../v1/" does not produce
	// ".../v1//models".
	endpoint := strings.TrimSuffix(baseURL, "/") + "/models"
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	c.setAuthHeaders(req, provider, apiKey)

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
		// The byte cap may split a multi-byte rune; drop any invalid tail.
		msg := strings.ToValidUTF8(strings.TrimSpace(string(body)), "")
		if msg == "" {
			msg = http.StatusText(resp.StatusCode)
		}
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, msg)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	var modelsResp ModelsResponse
	if err := json.Unmarshal(body, &modelsResp); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	return modelsResp.Data, nil
}
// ValidateAll validates the API key of every provider in selection's
// fallback chain, one after another in chain order, and returns the results
// keyed by provider.
//
// apiKeys supplies explicit keys. A provider absent from the map (or a nil
// map) is validated with an empty key, which makes ValidateAPIKey fall back
// to the provider's environment variable. A nil selection yields an empty,
// non-nil map. A provider listed more than once is validated only once.
func (c *Client) ValidateAll(ctx context.Context, selection *ProviderSelection, apiKeys map[Provider]string) map[Provider]*ValidationResult {
	if selection == nil {
		return make(map[Provider]*ValidationResult)
	}

	results := make(map[Provider]*ValidationResult, len(selection.FallbackChain))
	for _, provider := range selection.FallbackChain {
		if _, seen := results[provider]; seen {
			// Same provider, same key: another request would only repeat
			// the previous outcome.
			continue
		}
		// ValidateAPIKey always returns a non-nil result, and whenever it
		// returns an error it has also recorded it in result.ErrorMessage,
		// so the error value carries no extra information here.
		result, _ := c.ValidateAPIKey(ctx, provider, apiKeys[provider])
		results[provider] = result
	}
	return results
}
// setAuthHeaders sets authentication headers on req for provider.
//
// Every provider gets standard bearer-token auth. Z.AI, Venice and
// OpenRouter were previously handled case by case with exactly this header.
// Any other provider also gets it now, on the assumption that it is
// OpenAI-compatible. Before this change, such providers got no Authorization
// header at all, so every request to them was unauthenticated.
//
// NOTE(review): a provider that needs a different scheme (for example a
// custom API-key header instead of Authorization) needs its own case here.
func (c *Client) setAuthHeaders(req *http.Request, provider Provider, apiKey string) {
	req.Header.Set("Authorization", "Bearer "+apiKey)

	if provider == ProviderOpenRouter {
		// OpenRouter-specific headers identifying this application (site URL
		// and display title).
		req.Header.Set("HTTP-Referer", "https://github.com/openboatmobile/obm")
		req.Header.Set("X-Title", "OpenBoatMobile")
	}
}
// FormatValidationResults renders results as a human-readable report with
// one line per provider, in SortedProviders order.
//
// Each line shows the provider's display name and VALID/INVALID. It then
// shows, when present: the model count (valid keys only), the latency in
// milliseconds, and the error message. Providers absent from results, or
// mapped to a nil result, are skipped. Providers in results that
// SortedProviders does not return are not shown.
func FormatValidationResults(results map[Provider]*ValidationResult) string {
	var sb strings.Builder
	sb.WriteString("API Key Validation Results:\n\n")
	for _, provider := range SortedProviders() {
		// A missing key reads as nil, so this one check covers both absent
		// entries and caller-supplied nil values (which previously panicked).
		result := results[provider]
		if result == nil {
			continue
		}
		name, _, _ := provider.Info()
		status := "INVALID"
		if result.Valid {
			status = "VALID"
		}
		fmt.Fprintf(&sb, " %s: %s", name, status)
		if result.Valid && result.ModelCount > 0 {
			fmt.Fprintf(&sb, " (%d models)", result.ModelCount)
		}
		if result.Latency > 0 {
			fmt.Fprintf(&sb, " [%dms]", result.Latency)
		}
		if result.ErrorMessage != "" {
			fmt.Fprintf(&sb, " - %s", result.ErrorMessage)
		}
		sb.WriteString("\n")
	}
	return sb.String()
}
// FormatModelList renders models as a bulleted list under a header naming
// the provider (by its display name from provider.Info). Each bullet shows
// the model ID and, if known, its owner in parentheses. A trailing line
// gives the total. An empty list yields the header plus " No models found"
// and no total line.
func FormatModelList(models []Model, provider Provider) string {
	displayName, _, _ := provider.Info()

	var out strings.Builder
	fmt.Fprintf(&out, "Available models for %s:\n\n", displayName)

	if len(models) == 0 {
		out.WriteString(" No models found\n")
		return out.String()
	}

	for i := range models {
		out.WriteString(" - ")
		out.WriteString(models[i].ID)
		if owner := models[i].OwnedBy; owner != "" {
			out.WriteString(" (")
			out.WriteString(owner)
			out.WriteString(")")
		}
		out.WriteString("\n")
	}

	fmt.Fprintf(&out, "\nTotal: %d models\n", len(models))
	return out.String()
}