- Interactive deploy command with 8-step walkthrough: framework → provider → token → SSH → server → inference → tailscale → discord - .env file generation from walkthrough config - DeploymentConfig struct with framework-aware defaults - Inference API client with validation for Venice, OpenRouter, OpenAI, Anthropic - Hetzner Cloud provider: token validation, SSH key listing - DotEnv parser/writer with schema validation - Destroy command with confirmation prompt - Validation subcommand for checking existing .env files - All tests passing, go vet clean
287 lines
7.7 KiB
Go
287 lines
7.7 KiB
Go
// Package inference provides API client functionality for inference providers.
|
|
package inference
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Client provides HTTP client functionality for inference providers.
// Build one with NewClient or NewClientWithTimeout: the zero value has a nil
// httpClient, and the request methods call httpClient.Do directly.
type Client struct {
	// httpClient performs every outbound provider request.
	httpClient *http.Client
	// timeout records the duration httpClient.Timeout was configured with.
	timeout time.Duration
}
|
|
|
|
// NewClient creates a new inference client with default settings.
|
|
func NewClient() *Client {
|
|
return &Client{
|
|
httpClient: &http.Client{
|
|
Timeout: 30 * time.Second,
|
|
},
|
|
timeout: 30 * time.Second,
|
|
}
|
|
}
|
|
|
|
// NewClientWithTimeout creates a new inference client with custom timeout.
|
|
func NewClientWithTimeout(timeout time.Duration) *Client {
|
|
return &Client{
|
|
httpClient: &http.Client{
|
|
Timeout: timeout,
|
|
},
|
|
timeout: timeout,
|
|
}
|
|
}
|
|
|
|
// ValidationResult contains the result of an API validation attempt.
|
|
type ValidationResult struct {
|
|
Provider Provider `json:"provider"`
|
|
Valid bool `json:"valid"`
|
|
ErrorMessage string `json:"error_message,omitempty"`
|
|
ModelCount int `json:"model_count,omitempty"`
|
|
Latency int64 `json:"latency_ms"`
|
|
}
|
|
|
|
// ValidateAPIKey validates an API key for a provider by making a test request.
|
|
// Returns true if the API key is valid, false otherwise.
|
|
func (c *Client) ValidateAPIKey(ctx context.Context, provider Provider, apiKey string) (*ValidationResult, error) {
|
|
start := time.Now()
|
|
|
|
result := &ValidationResult{
|
|
Provider: provider,
|
|
}
|
|
|
|
// Get provider info
|
|
name, envVar, baseURL := provider.Info()
|
|
if baseURL == "" {
|
|
result.ErrorMessage = fmt.Sprintf("unknown provider: %s", provider)
|
|
return result, fmt.Errorf("unknown provider: %s", provider)
|
|
}
|
|
|
|
// Use provided API key orfall back to environment variable
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv(envVar)
|
|
}
|
|
|
|
if apiKey == "" {
|
|
result.ErrorMessage = fmt.Sprintf("%s API key not set (set %s)", name, envVar)
|
|
return result, nil
|
|
}
|
|
|
|
// Make a test request to list models (validates auth without consuming tokens)
|
|
url := fmt.Sprintf("%s/models", baseURL)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
result.ErrorMessage = fmt.Sprintf("failed to create request: %v", err)
|
|
return result, err
|
|
}
|
|
|
|
// Set auth headers based on provider
|
|
c.setAuthHeaders(req, provider, apiKey)
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
result.ErrorMessage = fmt.Sprintf("request failed: %v", err)
|
|
result.Latency = time.Since(start).Milliseconds()
|
|
return result, nil
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
result.Latency = time.Since(start).Milliseconds()
|
|
|
|
if resp.StatusCode == http.StatusOK {
|
|
// Parse response to count models
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err == nil {
|
|
var modelsResp ModelsResponse
|
|
if json.Unmarshal(body, &modelsResp) == nil {
|
|
result.ModelCount = len(modelsResp.Data)
|
|
}
|
|
}
|
|
result.Valid = true
|
|
return result, nil
|
|
}
|
|
|
|
// Handle specific error codes
|
|
switch resp.StatusCode {
|
|
case http.StatusUnauthorized:
|
|
result.ErrorMessage = "invalid API key"
|
|
case http.StatusForbidden:
|
|
result.ErrorMessage = "API key lacks required permissions"
|
|
case http.StatusTooManyRequests:
|
|
result.ErrorMessage = "rate limited - key is valid but throttled"
|
|
result.Valid = true // Key is valid, just throttled
|
|
default:
|
|
body, _ := io.ReadAll(resp.Body)
|
|
result.ErrorMessage = fmt.Sprintf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// Model is one entry of the "data" array returned by a /models endpoint.
type Model struct {
	// ID is the model identifier.
	ID string `json:"id"`
	// Object is the object type string reported by the provider.
	Object string `json:"object"`
	// Created is the creation time reported by the provider, if any
	// (presumably Unix seconds — confirm per provider).
	Created int64 `json:"created,omitempty"`
	// OwnedBy names the owning organization, if reported.
	OwnedBy string `json:"owned_by,omitempty"`
}
|
|
|
|
// ModelsResponse represents the response from the OpenAI-compatible /models endpoint.
|
|
type ModelsResponse struct {
|
|
Data []Model `json:"data"`
|
|
}
|
|
|
|
// ModelInfo holds descriptive details about an available model.
// NOTE(review): no use of this type is visible in this file — confirm callers.
type ModelInfo struct {
	// ID is the model identifier.
	ID string `json:"id"`
	// Provider names the provider offering the model.
	Provider string `json:"provider"`
	// Description is optional free-form text about the model.
	Description string `json:"description,omitempty"`
	// Context is the model's context window size, serialized as
	// "context_window" (presumably in tokens — confirm).
	Context int `json:"context_window,omitempty"`
}
|
|
|
|
// ListModels lists available models for a provider.
|
|
func (c *Client) ListModels(ctx context.Context, provider Provider, apiKey string) ([]Model, error) {
|
|
_, envVar, baseURL := provider.Info()
|
|
if baseURL == "" {
|
|
return nil, fmt.Errorf("unknown provider: %s", provider)
|
|
}
|
|
|
|
if apiKey == "" {
|
|
apiKey = os.Getenv(envVar)
|
|
}
|
|
|
|
if apiKey == "" {
|
|
return nil, fmt.Errorf("%s API key not set (set %s)", provider, envVar)
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/models", baseURL)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
c.setAuthHeaders(req, provider, apiKey)
|
|
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
|
}
|
|
|
|
var modelsResp ModelsResponse
|
|
if err := json.Unmarshal(body, &modelsResp); err != nil {
|
|
return nil, fmt.Errorf("failed to parse response: %w", err)
|
|
}
|
|
|
|
return modelsResp.Data, nil
|
|
}
|
|
|
|
// ValidateAll validates API keys for all providers in the fallback chain.
|
|
// Returns a map of provider to validation result.
|
|
func (c *Client) ValidateAll(ctx context.Context, selection *ProviderSelection, apiKeys map[Provider]string) map[Provider]*ValidationResult {
|
|
results := make(map[Provider]*ValidationResult)
|
|
|
|
for _, provider := range selection.FallbackChain {
|
|
apiKey := apiKeys[provider]
|
|
result, _ := c.ValidateAPIKey(ctx, provider, apiKey)
|
|
results[provider] = result
|
|
}
|
|
|
|
return results
|
|
}
|
|
|
|
// setAuthHeaders sets authentication headers for a request based on the provider.
|
|
func (c *Client) setAuthHeaders(req *http.Request, provider Provider, apiKey string) {
|
|
switch provider {
|
|
case ProviderZAI:
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
|
case ProviderVenice:
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
|
case ProviderOpenRouter:
|
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
|
// OpenRouter requires these additional headers
|
|
req.Header.Set("HTTP-Referer", "https://github.com/openboatmobile/obm")
|
|
req.Header.Set("X-Title", "OpenBoatMobile")
|
|
}
|
|
}
|
|
|
|
// FormatValidationResults formats validation results for display.
|
|
func FormatValidationResults(results map[Provider]*ValidationResult) string {
|
|
var sb strings.Builder
|
|
|
|
sb.WriteString("API Key Validation Results:\n\n")
|
|
|
|
for _, provider := range SortedProviders() {
|
|
result, ok := results[provider]
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
name, _, _ := provider.Info()
|
|
status := "INVALID"
|
|
if result.Valid {
|
|
status = "VALID"
|
|
}
|
|
|
|
fmt.Fprintf(&sb, " %s: %s", name, status)
|
|
|
|
if result.Valid && result.ModelCount > 0 {
|
|
fmt.Fprintf(&sb, " (%d models)", result.ModelCount)
|
|
}
|
|
|
|
if result.Latency > 0 {
|
|
fmt.Fprintf(&sb, " [%dms]", result.Latency)
|
|
}
|
|
|
|
if result.ErrorMessage != "" {
|
|
fmt.Fprintf(&sb, " - %s", result.ErrorMessage)
|
|
}
|
|
|
|
sb.WriteString("\n")
|
|
}
|
|
|
|
return sb.String()
|
|
}
|
|
|
|
// FormatModelList formats a model list for display.
|
|
func FormatModelList(models []Model, provider Provider) string {
|
|
var sb strings.Builder
|
|
|
|
name, _, _ := provider.Info()
|
|
fmt.Fprintf(&sb, "Available models for %s:\n\n", name)
|
|
|
|
if len(models) == 0 {
|
|
sb.WriteString(" No models found\n")
|
|
return sb.String()
|
|
}
|
|
|
|
for _, model := range models {
|
|
fmt.Fprintf(&sb, " - %s", model.ID)
|
|
if model.OwnedBy != "" {
|
|
fmt.Fprintf(&sb, " (%s)", model.OwnedBy)
|
|
}
|
|
sb.WriteString("\n")
|
|
}
|
|
|
|
fmt.Fprintf(&sb, "\nTotal: %d models\n", len(models))
|
|
|
|
return sb.String()
|
|
}
|