mcp-framework/cli/resolve.go

283 lines
6.1 KiB
Go

package cli
import (
"errors"
"fmt"
"io"
"slices"
"strings"
)
// ValueSource names a place a configuration value can be read from.
type ValueSource string

// Recognized value sources. SourceDefault is reserved for values taken from
// FieldSpec.DefaultValue and is rejected when it appears in a resolution
// order.
const (
	SourceFlag    ValueSource = "flag"
	SourceEnv     ValueSource = "env"
	SourceConfig  ValueSource = "config"
	SourceSecret  ValueSource = "secret"
	SourceDefault ValueSource = "default"
)

// DefaultResolutionOrder is the precedence applied when ResolveOptions.Order
// is empty. The first source that yields a non-blank value wins, so flags
// override environment variables, which override config, which overrides
// secrets.
//
// NOTE(review): this is an exported, mutable slice; mutating it changes the
// default precedence for every later ResolveFields call.
var DefaultResolutionOrder = []ValueSource{SourceFlag, SourceEnv, SourceConfig, SourceSecret}

// ErrInvalidResolverInput is wrapped by errors reporting malformed resolver
// input (nil lookup, invalid source order, blank or duplicate field names) and
// by StaticLookup when asked about a reserved or unknown source.
var ErrInvalidResolverInput = errors.New("invalid resolver input")
// FieldSpec describes one configuration value to resolve.
//
// Name identifies the field in the resulting Resolution. It is also the lookup
// key for any source whose explicit key (FlagKey, EnvKey, ConfigKey, SecretKey)
// is blank. When Sources is non-empty, it replaces the global resolution order
// for this field only. DefaultValue is used when no source yields a non-blank
// value. A Required field that is still unresolved after that is reported as
// missing.
type FieldSpec struct {
	Name         string
	Required     bool
	DefaultValue string
	Sources      []ValueSource
	FlagKey      string
	EnvKey       string
	ConfigKey    string
	SecretKey    string
}

// LookupFunc fetches the raw value stored under key in source and reports
// whether the key was present. A non-nil error aborts resolution.
type LookupFunc func(source ValueSource, key string) (string, bool, error)

// ResolveOptions configures ResolveFields. Lookup is mandatory. When Order is
// empty, DefaultResolutionOrder applies.
type ResolveOptions struct {
	Fields []FieldSpec
	Order  []ValueSource
	Lookup LookupFunc
}
// ResolvedField records the outcome of resolving a single FieldSpec. When
// Found is false, neither a source nor a default supplied a value, and Value
// and Source are left empty.
type ResolvedField struct {
	Name   string
	Value  string
	Source ValueSource
	Found  bool
}

// Resolution holds the resolved fields, appended in the order their specs
// were processed.
type Resolution struct {
	Fields []ResolvedField
}

// Get returns the field whose Name equals name after surrounding whitespace
// is trimmed from name. It reports false when no field matches.
func (r Resolution) Get(name string) (ResolvedField, bool) {
	target := strings.TrimSpace(name)
	for i := range r.Fields {
		if r.Fields[i].Name != target {
			continue
		}
		return r.Fields[i], true
	}
	return ResolvedField{}, false
}
// MissingRequiredValuesError lists the Required fields that ResolveFields
// could not resolve from any source or default.
type MissingRequiredValuesError struct {
	Fields []string
}

// Error joins the missing field names with ", ", keeping the order in which
// they are stored.
func (e *MissingRequiredValuesError) Error() string {
	names := strings.Join(e.Fields, ", ")
	return fmt.Sprintf("missing required configuration values: %s", names)
}
// SourceLookupError reports a LookupFunc failure while resolving Field from
// Source under Key. Err is the error the lookup returned and is exposed
// through Unwrap for errors.Is / errors.As.
type SourceLookupError struct {
	Field  string
	Source ValueSource
	Key    string
	Err    error
}

// Error formats the field, source and key as quoted strings, followed by the
// underlying error.
func (e *SourceLookupError) Error() string {
	const layout = "resolve %q from %q (key %q): %v"
	return fmt.Sprintf(layout, e.Field, e.Source, e.Key, e.Err)
}

// Unwrap returns the underlying lookup error.
func (e *SourceLookupError) Unwrap() error { return e.Err }
// StaticLookup answers lookups from fixed in-memory maps, one per source.
// A nil map reports every key as absent.
type StaticLookup struct {
	Flags   map[string]string
	Env     map[string]string
	Config  map[string]string
	Secrets map[string]string
}

// Lookup returns the value stored under key in the map that belongs to
// source, plus whether the key was present. It never fails for the four
// lookup sources. SourceDefault and unrecognized sources produce an error
// wrapping ErrInvalidResolverInput. Its signature matches LookupFunc.
func (l StaticLookup) Lookup(source ValueSource, key string) (string, bool, error) {
	var table map[string]string
	switch source {
	case SourceDefault:
		// Defaults come from FieldSpec.DefaultValue, never from a lookup.
		return "", false, fmt.Errorf("%w: source %q is reserved", ErrInvalidResolverInput, source)
	case SourceFlag:
		table = l.Flags
	case SourceEnv:
		table = l.Env
	case SourceConfig:
		table = l.Config
	case SourceSecret:
		table = l.Secrets
	default:
		return "", false, fmt.Errorf("%w: unknown source %q", ErrInvalidResolverInput, source)
	}
	got, present := table[key]
	return got, present, nil
}
// ResolveFields resolves every field in options.Fields and returns the
// results in input order.
//
// For each field, sources are consulted in order. The field's own Sources are
// used when set; otherwise options.Order, or DefaultResolutionOrder when that
// is empty. The first source that reports a present, non-blank value wins, and
// its value is returned with surrounding whitespace trimmed. If no source
// supplies a value, a non-blank (trimmed) DefaultValue is used and attributed
// to SourceDefault.
//
// Errors:
//   - an error wrapping ErrInvalidResolverInput, with an empty Resolution,
//     when Lookup is nil, an order is invalid, or a field name is blank or
//     duplicated (names are compared after trimming);
//   - *SourceLookupError when Lookup fails. The returned Resolution then
//     holds only the fields resolved before the failing one;
//   - *MissingRequiredValuesError when any Required field is unresolved. The
//     returned Resolution is complete, so callers can still inspect or render
//     it.
func ResolveFields(options ResolveOptions) (Resolution, error) {
	if options.Lookup == nil {
		return Resolution{}, fmt.Errorf("%w: lookup is required", ErrInvalidResolverInput)
	}
	globalOrder, err := normalizeOrder(options.Order, DefaultResolutionOrder)
	if err != nil {
		return Resolution{}, fmt.Errorf("%w: %v", ErrInvalidResolverInput, err)
	}

	resolution := Resolution{Fields: make([]ResolvedField, 0, len(options.Fields))}
	var missingRequired []string
	seenNames := make(map[string]struct{}, len(options.Fields))
	for i, spec := range options.Fields {
		name := strings.TrimSpace(spec.Name)
		if name == "" {
			return Resolution{}, fmt.Errorf("%w: field at index %d has empty name", ErrInvalidResolverInput, i)
		}
		if _, exists := seenNames[name]; exists {
			return Resolution{}, fmt.Errorf("%w: duplicate field name %q", ErrInvalidResolverInput, name)
		}
		seenNames[name] = struct{}{}

		order := globalOrder
		if len(spec.Sources) > 0 {
			if order, err = normalizeOrder(spec.Sources, globalOrder); err != nil {
				return Resolution{}, fmt.Errorf("%w: field %q: %v", ErrInvalidResolverInput, name, err)
			}
		}

		field, lookupErr := resolveField(spec, name, order, options.Lookup)
		if lookupErr != nil {
			// Partial results are returned deliberately (matching historical
			// behavior); fields from the failing one onward are absent.
			return resolution, lookupErr
		}
		if spec.Required && !field.Found {
			missingRequired = append(missingRequired, name)
		}
		resolution.Fields = append(resolution.Fields, field)
	}

	if len(missingRequired) > 0 {
		return resolution, &MissingRequiredValuesError{Fields: missingRequired}
	}
	return resolution, nil
}

// resolveField walks order for a single spec whose trimmed, validated name is
// name. It returns the first present, non-blank value (trimmed), falling back
// to the spec's trimmed DefaultValue. A lookup failure is returned as a
// *SourceLookupError.
func resolveField(spec FieldSpec, name string, order []ValueSource, lookup LookupFunc) (ResolvedField, error) {
	field := ResolvedField{Name: name}
	for _, source := range order {
		key := spec.keyFor(source)
		if key == "" {
			// Defensive: keyFor yields "" only for sources without a key field,
			// which a normalized order does not contain.
			continue
		}
		value, found, err := lookup(source, key)
		if err != nil {
			return ResolvedField{}, &SourceLookupError{Field: name, Source: source, Key: key, Err: err}
		}
		// A present but blank value does not count; later sources still apply.
		if trimmed := strings.TrimSpace(value); found && trimmed != "" {
			field.Value, field.Source, field.Found = trimmed, source, true
			return field, nil
		}
	}
	if def := strings.TrimSpace(spec.DefaultValue); def != "" {
		field.Value, field.Source, field.Found = def, SourceDefault, true
	}
	return field, nil
}
// RenderResolutionProvenance writes one "name: source" line per field to w,
// using "missing" for fields whose Found flag is false. Only provenance is
// written, never field values. The first write error is returned unchanged.
func RenderResolutionProvenance(w io.Writer, resolution Resolution) error {
	for _, f := range resolution.Fields {
		label := string(f.Source)
		if !f.Found {
			label = "missing"
		}
		if _, err := fmt.Fprintf(w, "%s: %s\n", f.Name, label); err != nil {
			return err
		}
	}
	return nil
}
// normalizeOrder validates a resolution order and returns a freshly allocated,
// de-duplicated copy of it. When input is empty, fallback is used instead.
// Every entry must be a known source other than SourceDefault. A duplicate
// keeps the position of its first occurrence. An order that ends up empty is
// an error.
//
// The returned slice never aliases input or fallback, so callers may keep or
// modify it freely.
func normalizeOrder(input []ValueSource, fallback []ValueSource) ([]ValueSource, error) {
	order := input
	if len(order) == 0 {
		order = fallback
	}
	result := make([]ValueSource, 0, len(order))
	for _, source := range order {
		if !isKnownSource(source) {
			return nil, fmt.Errorf("unknown source %q", source)
		}
		if source == SourceDefault {
			return nil, fmt.Errorf("source %q cannot be used in resolution order", source)
		}
		// result can only ever hold the four non-default sources, so a linear
		// scan is cheaper than allocating a set on every call.
		if slices.Contains(result, source) {
			continue
		}
		result = append(result, source)
	}
	if len(result) == 0 {
		return nil, errors.New("resolution order is empty")
	}
	return result, nil
}
// isKnownSource reports whether source is one of the declared ValueSource
// constants, including the reserved SourceDefault.
func isKnownSource(source ValueSource) bool {
	return source == SourceFlag ||
		source == SourceEnv ||
		source == SourceConfig ||
		source == SourceSecret ||
		source == SourceDefault
}
// keyFor returns the key used to look this field up in source. The key is the
// source's explicit key when it is non-blank, otherwise the field Name, with
// surrounding whitespace trimmed either way. Sources that have no key field
// (SourceDefault or unrecognized values) yield "".
func (s FieldSpec) keyFor(source ValueSource) string {
	var explicit string
	switch source {
	case SourceFlag:
		explicit = s.FlagKey
	case SourceEnv:
		explicit = s.EnvKey
	case SourceConfig:
		explicit = s.ConfigKey
	case SourceSecret:
		explicit = s.SecretKey
	default:
		return ""
	}
	return fallbackKey(explicit, s.Name)
}
// fallbackKey returns explicit with surrounding whitespace removed. If that
// leaves it blank, it returns the trimmed fallback instead.
func fallbackKey(explicit, fallback string) string {
	key := strings.TrimSpace(explicit)
	if key == "" {
		key = strings.TrimSpace(fallback)
	}
	return key
}