Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pkg/config/model_alias_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ func TestResolveModelAliases(t *testing.T) {
mockData := &modelsdev.Database{
Providers: map[string]modelsdev.Provider{
"anthropic": {
ID: "anthropic",
Name: "Anthropic",
Models: map[string]modelsdev.Model{
"claude-sonnet-4-5": {Name: "Claude Sonnet 4.5 (latest)"},
"claude-sonnet-4-5-20250929": {Name: "Claude Sonnet 4.5"},
Expand Down
122 changes: 32 additions & 90 deletions pkg/modelsdev/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ type Store struct {
cacheFile string
mu sync.Mutex
db *Database
etag string // ETag from last successful fetch, used for conditional requests
}

// singleton holds the process-wide Store instance. It is initialised lazily
// on the first call to NewStore. All subsequent calls return the same value.
var singleton = sync.OnceValues(func() (*Store, error) {
// NewStore returns the process-wide singleton Store.
//
// The database is loaded lazily on the first call to GetDatabase and
// then cached in memory so that every caller shares one copy.
// The first call creates the cache directory if it does not exist.
var NewStore = sync.OnceValues(func() (*Store, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get user home directory: %w", err)
Expand All @@ -52,15 +54,6 @@ var singleton = sync.OnceValues(func() (*Store, error) {
}, nil
})

// NewStore hands back the single Store shared by the whole process.
//
// It delegates to the package-level singleton initialiser, so every call
// yields the same (*Store, error) pair that the first call produced.
// The database itself is not loaded here; it is loaded lazily on the
// first call to GetDatabase and then kept in memory for all callers.
// The first call creates the cache directory if it does not exist.
func NewStore() (*Store, error) {
	store, err := singleton()
	return store, err
}

// NewDatabaseStore creates a Store pre-populated with the given database.
// The returned store serves data entirely from memory and never fetches
// from the network or touches the filesystem, making it suitable for
Expand All @@ -78,18 +71,17 @@ func (s *Store) GetDatabase(ctx context.Context) (*Database, error) {
return s.db, nil
}

db, etag, err := loadDatabase(ctx, s.cacheFile)
db, err := loadDatabase(ctx, s.cacheFile)
if err != nil {
return nil, err
}

s.db = db
s.etag = etag
return db, nil
}

// GetProvider returns a specific provider by ID.
func (s *Store) GetProvider(ctx context.Context, providerID string) (*Provider, error) {
// getProvider returns a specific provider by ID.
func (s *Store) getProvider(ctx context.Context, providerID string) (*Provider, error) {
db, err := s.GetDatabase(ctx)
if err != nil {
return nil, err
Expand All @@ -112,30 +104,23 @@ func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) {
providerID := parts[0]
modelID := parts[1]

provider, err := s.GetProvider(ctx, providerID)
provider, err := s.getProvider(ctx, providerID)
if err != nil {
return nil, err
}

model, exists := provider.Models[modelID]
if !exists {
// For amazon-bedrock, try stripping region/inference profile prefixes
// Bedrock uses prefixes for cross-region inference profiles,
// but models.dev stores models without these prefixes.
//
// Strip known region prefixes and retry lookup.
if providerID == "amazon-bedrock" {
if before, after, ok := strings.Cut(modelID, "."); ok {
possibleRegionPrefix := before
if isBedrockRegionPrefix(possibleRegionPrefix) {
normalizedModelID := after
model, exists = provider.Models[normalizedModelID]
if exists {
return &model, nil
}
}
}

// For amazon-bedrock, try stripping region/inference profile prefixes.
// Bedrock uses prefixes for cross-region inference profiles,
// but models.dev stores models without these prefixes.
if !exists && providerID == "amazon-bedrock" {
if prefix, after, ok := strings.Cut(modelID, "."); ok && bedrockRegionPrefixes[prefix] {
model, exists = provider.Models[after]
}
}

if !exists {
return nil, fmt.Errorf("model %q not found in provider %q", modelID, providerID)
}

Expand All @@ -144,12 +129,11 @@ func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) {

// loadDatabase loads the database from the local cache file or
// falls back to fetching from the models.dev API.
// It returns the database and the ETag associated with the data.
func loadDatabase(ctx context.Context, cacheFile string) (*Database, string, error) {
func loadDatabase(ctx context.Context, cacheFile string) (*Database, error) {
// Try to load from cache first
cached, err := loadFromCache(cacheFile)
if err == nil && time.Since(cached.LastRefresh) < refreshInterval {
return &cached.Database, cached.ETag, nil
return &cached.Database, nil
}

// Cache is stale or doesn't exist — try a conditional fetch with the ETag.
Expand All @@ -163,9 +147,9 @@ func loadDatabase(ctx context.Context, cacheFile string) (*Database, string, err
// If API fetch fails but we have cached data, use it regardless of age.
if cached != nil {
slog.Debug("API fetch failed, using stale cache", "error", fetchErr)
return &cached.Database, cached.ETag, nil
return &cached.Database, nil
}
return nil, "", fmt.Errorf("failed to fetch from API and no cached data available: %w", fetchErr)
return nil, fmt.Errorf("failed to fetch from API and no cached data available: %w", fetchErr)
}

// database is nil when the server returned 304 Not Modified.
Expand All @@ -175,15 +159,15 @@ func loadDatabase(ctx context.Context, cacheFile string) (*Database, string, err
if saveErr := saveToCache(cacheFile, &cached.Database, cached.ETag); saveErr != nil {
slog.Warn("Failed to update cache timestamp", "error", saveErr)
}
return &cached.Database, cached.ETag, nil
return &cached.Database, nil
}

// Save the fresh data to cache.
if saveErr := saveToCache(cacheFile, database, newETag); saveErr != nil {
slog.Warn("Failed to save to cache", "error", saveErr)
}

return database, newETag, nil
return database, nil
}

// fetchFromAPI fetches the models.dev database.
Expand Down Expand Up @@ -230,7 +214,6 @@ func fetchFromAPI(ctx context.Context, etag string) (*Database, string, error) {

return &Database{
Providers: providers,
UpdatedAt: time.Now(),
}, newETag, nil
}

Expand All @@ -249,11 +232,9 @@ func loadFromCache(cacheFile string) (*CachedData, error) {
}

func saveToCache(cacheFile string, database *Database, etag string) error {
now := time.Now()
cached := CachedData{
Database: *database,
CachedAt: now,
LastRefresh: now,
LastRefresh: time.Now(),
ETag: etag,
}

Expand Down Expand Up @@ -286,8 +267,7 @@ func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName str
return modelName
}

// Get the provider from the database
provider, err := s.GetProvider(ctx, providerID)
provider, err := s.getProvider(ctx, providerID)
if err != nil {
return modelName
}
Expand Down Expand Up @@ -319,46 +299,8 @@ func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName str
// stores models without regional prefixes. AWS uses these for cross-region inference profiles.
// See: https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference.html
var bedrockRegionPrefixes = map[string]bool{
"us": true, // US region inference profile
"eu": true, // EU region inference profile
"apac": true, // Asia Pacific region inference profile
"global": true, // Global inference profile (routes to any available region)
}

// isBedrockRegionPrefix reports whether prefix is one of the known Bedrock
// regional/inference-profile prefixes recorded in bedrockRegionPrefixes.
// Unknown prefixes (including the empty string) yield false because a
// missing map key reads as the zero value.
func isBedrockRegionPrefix(prefix string) bool {
	known := bedrockRegionPrefixes[prefix]
	return known
}

// ModelSupportsReasoning checks if the given model ID supports reasoning/thinking.
//
// This function implements fail-open semantics:
// - If modelID is empty or not in "provider/model" format, returns true (fail-open)
// - If models.dev lookup fails for any reason, returns true (fail-open)
// - If lookup succeeds, returns the model's Reasoning field value
func ModelSupportsReasoning(ctx context.Context, modelID string) bool {
// Fail-open for empty model ID
if modelID == "" {
return true
}

// Fail-open if not in provider/model format
if !strings.Contains(modelID, "/") {
slog.Debug("Model ID not in provider/model format, assuming reasoning supported to allow user choice", "model_id", modelID)
return true
}

store, err := NewStore()
if err != nil {
slog.Debug("Failed to create modelsdev store, assuming reasoning supported to allow user choice", "error", err)
return true
}

model, err := store.GetModel(ctx, modelID)
if err != nil {
slog.Debug("Failed to lookup model in models.dev, assuming reasoning supported to allow user choice", "model_id", modelID, "error", err)
return true
}

return model.Reasoning
"us": true,
"eu": true,
"apac": true,
"global": true,
}
27 changes: 5 additions & 22 deletions pkg/modelsdev/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,20 @@ import "time"
// Database represents the complete models.dev database.
//
// It is the in-memory form that Store caches and that is embedded in
// CachedData when persisted to the cache file.
type Database struct {
// Providers maps a provider ID (e.g. "anthropic") to its Provider entry.
Providers map[string]Provider `json:"providers"`
// UpdatedAt is set to the current time by fetchFromAPI when it builds
// a Database from a fresh API response.
UpdatedAt time.Time `json:"updated_at"`
}

// Provider represents an AI model provider as described by models.dev.
type Provider struct {
// ID is the provider identifier; in the test fixtures it matches the
// key used in Database.Providers (e.g. "openai").
ID string `json:"id"`
// Name is the human-readable provider name (e.g. "OpenAI").
Name string `json:"name"`
// Doc, API, NPM and Env are optional metadata fields copied from the
// models.dev payload.
// NOTE(review): presumably documentation URL, API base URL, npm package
// name and environment-variable names respectively — confirm against
// the models.dev schema.
Doc string `json:"doc,omitempty"`
API string `json:"api,omitempty"`
NPM string `json:"npm,omitempty"`
Env []string `json:"env,omitempty"`
// Models maps a model ID (without the provider prefix) to its Model.
Models map[string]Model `json:"models"`
}

// Model represents an AI model with its specifications and capabilities
type Model struct {
ID string `json:"id"`
Name string `json:"name"`
Family string `json:"family,omitempty"`
Attachment bool `json:"attachment"`
Reasoning bool `json:"reasoning"`
Temperature bool `json:"temperature"`
ToolCall bool `json:"tool_call"`
Knowledge string `json:"knowledge,omitempty"`
ReleaseDate string `json:"release_date"`
LastUpdated string `json:"last_updated"`
OpenWeights bool `json:"open_weights"`
Cost *Cost `json:"cost,omitempty"`
Limit Limit `json:"limit"`
Modalities Modalities `json:"modalities"`
Name string `json:"name"`
Family string `json:"family,omitempty"`
Cost *Cost `json:"cost,omitempty"`
Limit Limit `json:"limit"`
Modalities Modalities `json:"modalities"`
}

// Cost represents the pricing information for a model
Expand All @@ -60,7 +44,6 @@ type Modalities struct {
// CachedData represents the cached models.dev data with metadata.
// It is the structure written by saveToCache and read by loadFromCache.
type CachedData struct {
// Database is the cached copy of the models.dev database.
Database Database `json:"database"`
// CachedAt is set to the current time by saveToCache each time the
// cache file is written.
CachedAt time.Time `json:"cached_at"`
// LastRefresh is compared against refreshInterval by loadDatabase to
// decide whether the cache is still fresh.
LastRefresh time.Time `json:"last_refresh"`
// ETag is the ETag stored with the cached data; it is omitted from the
// JSON when empty and is used for conditional fetches.
ETag string `json:"etag,omitempty"`
}
14 changes: 0 additions & 14 deletions pkg/runtime/model_switcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,25 +244,20 @@ func TestBuildCatalogChoices(t *testing.T) {
db := &modelsdev.Database{
Providers: map[string]modelsdev.Provider{
"openai": {
ID: "openai",
Name: "OpenAI",
Models: map[string]modelsdev.Model{
"gpt-4o": {
ID: "gpt-4o",
Name: "GPT-4o",
Modalities: modelsdev.Modalities{
Output: []string{"text"},
},
},
"dall-e-3": {
ID: "dall-e-3",
Name: "DALL-E 3",
Modalities: modelsdev.Modalities{
Output: []string{"image"}, // Not a text model
},
},
"text-embedding-3-large": {
ID: "text-embedding-3-large",
Name: "Text Embedding 3 Large",
Family: "text-embedding",
Modalities: modelsdev.Modalities{
Expand All @@ -272,11 +267,8 @@ func TestBuildCatalogChoices(t *testing.T) {
},
},
"anthropic": {
ID: "anthropic",
Name: "Anthropic",
Models: map[string]modelsdev.Model{
"claude-sonnet-4-0": {
ID: "claude-sonnet-4-0",
Name: "Claude Sonnet 4",
Modalities: modelsdev.Modalities{
Output: []string{"text"},
Expand All @@ -285,11 +277,8 @@ func TestBuildCatalogChoices(t *testing.T) {
},
},
"unsupported": {
ID: "unsupported",
Name: "Unsupported Provider",
Models: map[string]modelsdev.Model{
"some-model": {
ID: "some-model",
Name: "Some Model",
Modalities: modelsdev.Modalities{
Output: []string{"text"},
Expand Down Expand Up @@ -348,11 +337,8 @@ func TestBuildCatalogChoicesWithDuplicates(t *testing.T) {
db := &modelsdev.Database{
Providers: map[string]modelsdev.Provider{
"openai": {
ID: "openai",
Name: "OpenAI",
Models: map[string]modelsdev.Model{
"gpt-4o": {
ID: "gpt-4o",
Name: "GPT-4o",
Modalities: modelsdev.Modalities{
Output: []string{"text"},
Expand Down
Loading