mcp-framework/cli/setup.go

531 lines
12 KiB
Go
Raw Permalink Normal View History

package cli
import (
"bufio"
"errors"
"fmt"
"io"
"os"
"slices"
"strings"
"golang.org/x/term"
)
// SetupFieldType selects how RunSetup prompts for, parses, and validates a
// SetupField.
type SetupFieldType string
const (
// SetupFieldString is a free-form text value.
SetupFieldString SetupFieldType = "string"
// SetupFieldURL is a text value that must pass ValidateBaseURL when set.
SetupFieldURL SetupFieldType = "url"
// SetupFieldSecret is a text value that is read without echo when stdin is a
// terminal. When SetupField.ExistingSecret is set, blank input keeps it.
SetupFieldSecret SetupFieldType = "secret"
// SetupFieldBool is a yes/no value (accepts forms such as y/n, yes/no,
// true/false, t/f, 1/0, on/off, case-insensitively).
SetupFieldBool SetupFieldType = "bool"
// SetupFieldList is a list of entries separated by SetupField.ListSeparator.
SetupFieldList SetupFieldType = "list"
)
// ErrInvalidSetupDefinition is wrapped (use errors.Is) by the error RunSetup
// returns when the field definitions themselves are malformed: empty or
// duplicate names, unknown types, or defaults that cannot be parsed.
var ErrInvalidSetupDefinition = errors.New("invalid setup definition")
// SetupField describes one value that RunSetup prompts for.
type SetupField struct {
// Name identifies the field in the SetupResult. It is trimmed, must be
// non-empty, and must be unique within SetupOptions.Fields.
Name string
// Label is the prompt text; when blank (after trimming) Name is used.
Label string
// Type selects prompting and parsing behavior; it must be one of the
// SetupField* constants.
Type SetupFieldType
// Required rejects answers that yield no value, after defaults and stored
// secrets have been considered, with "value is required".
Required bool
// Default is the textual default. It is normalized for string/url/secret
// fields (and must be a valid URL for url fields), parsed as a boolean for
// bool fields, and split on ListSeparator for list fields.
Default string
// ExistingSecret, for secret fields, is a previously stored value that is
// kept when the answer is blank. The prompt never prints it.
ExistingSecret string
// ListSeparator splits list answers and list defaults; it defaults to ",".
ListSeparator string
// Normalize cleans text before it is checked; it defaults to
// strings.TrimSpace. For list fields it is applied to each entry, and
// entries that normalize to "" are dropped.
Normalize func(string) string
// Validate optionally checks a string, url, or secret value when one is set.
Validate func(string) error
// ValidateBool optionally checks the value of a bool field once it is set.
ValidateBool func(bool) error
// ValidateList optionally checks the parsed entries of a list field.
ValidateList func([]string) error
}
// SetupOptions configures RunSetup.
type SetupOptions struct {
// Fields are prompted for in order.
Fields []SetupField
// Stdin is read for answers; nil means os.Stdin. When it is a terminal,
// secrets are read without echo and invalid answers are re-prompted.
Stdin *os.File
// Stdout receives prompts and retry hints; nil means os.Stdout.
Stdout io.Writer
}
// SetupValue is the parsed answer for one field. Only the member matching
// Type is meaningful.
type SetupValue struct {
// Type is the SetupFieldType of the field that produced this value.
Type SetupFieldType
// String holds the value of string, url, and secret fields.
String string
// Bool holds the value of bool fields.
Bool bool
// List holds the entries of list fields.
List []string
// Set reports whether a value was obtained, whether from the answer, a
// default, or a stored secret.
Set bool
// KeptStoredSecret reports that a blank answer kept SetupField.ExistingSecret.
KeptStoredSecret bool
}
// SetupResultField pairs a field's normalized Name with its parsed value.
type SetupResultField struct {
Name string
Value SetupValue
}
// SetupResult collects the answers produced by RunSetup.
type SetupResult struct {
// Fields holds one entry per answered field, in the order the fields were
// defined.
Fields []SetupResultField
}
// Get looks up the answer recorded for the field called name. Surrounding
// whitespace in name is ignored. The boolean is false when no field with
// that name was recorded; if several match, the first one wins.
func (r SetupResult) Get(name string) (SetupValue, bool) {
	key := strings.TrimSpace(name)
	idx := slices.IndexFunc(r.Fields, func(f SetupResultField) bool {
		return f.Name == key
	})
	if idx < 0 {
		return SetupValue{}, false
	}
	return r.Fields[idx].Value, true
}
// SetupValidationError reports an answer that failed parsing or validation.
// RunSetup returns it when input is not a terminal, so the field cannot be
// re-prompted.
type SetupValidationError struct {
// Field is the field's Name.
Field string
// Label is the field's prompt label.
Label string
// Message describes why the answer was rejected.
Message string
}
// Error formats the failure as "<label>: <message>". It prefers the trimmed
// Label, falls back to the trimmed Field, and returns only the trimmed
// Message when both are blank.
func (e *SetupValidationError) Error() string {
	msg := strings.TrimSpace(e.Message)
	for _, candidate := range []string{e.Label, e.Field} {
		if prefix := strings.TrimSpace(candidate); prefix != "" {
			return fmt.Sprintf("%s: %s", prefix, msg)
		}
	}
	return msg
}
// normalizedSetupField is a SetupField after normalizeSetupFields has
// trimmed its text and applied the Label, Normalize, and ListSeparator
// fallbacks. It also parses Default into the representation matching Type.
type normalizedSetupField struct {
Name string
Label string
Type SetupFieldType
Required bool
// DefaultString is the normalized default for string, url, and secret fields.
DefaultString string
// DefaultBool is the parsed default for bool fields; nil means no default.
DefaultBool *bool
// DefaultList is the split, normalized default for list fields.
DefaultList []string
ExistingSecret string
ListSeparator string
Normalize func(string) string
Validate func(string) error
ValidateBool func(bool) error
ValidateList func([]string) error
}
// RunSetup interactively collects a value for each field in options.Fields,
// in order, and returns them as a SetupResult.
//
// Field definitions are checked up front; malformed definitions yield an
// error wrapping ErrInvalidSetupDefinition. When Stdin is a terminal,
// secrets are read without echo and invalid answers are re-prompted.
// Otherwise the first invalid answer aborts with a *SetupValidationError.
// On error the returned result still holds the fields collected so far.
func RunSetup(options SetupOptions) (SetupResult, error) {
	in, out := normalizeSetupIO(options)
	defs, err := normalizeSetupFields(options.Fields)
	if err != nil {
		return SetupResult{}, err
	}

	fd := int(in.Fd())
	interactive := term.IsTerminal(fd)
	lines := bufio.NewReader(in)

	collected := SetupResult{Fields: make([]SetupResultField, 0, len(defs))}
	for i := range defs {
		answer, promptErr := promptSetupField(lines, in, out, fd, interactive, defs[i])
		if promptErr != nil {
			return collected, promptErr
		}
		collected.Fields = append(collected.Fields, SetupResultField{
			Name:  defs[i].Name,
			Value: answer,
		})
	}
	return collected, nil
}
// normalizeSetupIO returns the input and output streams from options. It
// substitutes os.Stdin and os.Stdout for nil values.
func normalizeSetupIO(options SetupOptions) (*os.File, io.Writer) {
	in, out := options.Stdin, options.Stdout
	if in == nil {
		in = os.Stdin
	}
	if out == nil {
		out = os.Stdout
	}
	return in, out
}
// normalizeSetupFields validates the caller's field definitions and converts
// them into normalizedSetupField values, preserving order.
//
// Names are trimmed and must be non-empty and unique, and types must be
// known. Blank labels fall back to the name, a nil Normalize falls back to
// strings.TrimSpace, and an empty ListSeparator falls back to ",". Every
// definition error wraps ErrInvalidSetupDefinition.
func normalizeSetupFields(fields []SetupField) ([]normalizedSetupField, error) {
	out := make([]normalizedSetupField, 0, len(fields))
	seen := make(map[string]struct{}, len(fields))
	for idx, def := range fields {
		name := strings.TrimSpace(def.Name)
		if name == "" {
			return nil, fmt.Errorf("%w: field at index %d has empty name", ErrInvalidSetupDefinition, idx)
		}
		if _, dup := seen[name]; dup {
			return nil, fmt.Errorf("%w: duplicate field name %q", ErrInvalidSetupDefinition, name)
		}
		seen[name] = struct{}{}
		if !isKnownSetupFieldType(def.Type) {
			return nil, fmt.Errorf("%w: field %q uses unknown type %q", ErrInvalidSetupDefinition, name, def.Type)
		}

		entry := normalizedSetupField{
			Name:           name,
			Label:          strings.TrimSpace(def.Label),
			Type:           def.Type,
			Required:       def.Required,
			ExistingSecret: strings.TrimSpace(def.ExistingSecret),
			ListSeparator:  def.ListSeparator,
			Normalize:      def.Normalize,
			Validate:       def.Validate,
			ValidateBool:   def.ValidateBool,
			ValidateList:   def.ValidateList,
		}
		if entry.Label == "" {
			entry.Label = name
		}
		if entry.Normalize == nil {
			entry.Normalize = strings.TrimSpace
		}
		if entry.ListSeparator == "" {
			entry.ListSeparator = ","
		}
		if err := normalizeSetupFieldDefault(&entry, def.Default); err != nil {
			return nil, err
		}
		out = append(out, entry)
	}
	return out, nil
}

// normalizeSetupFieldDefault parses raw (a SetupField.Default) according to
// entry.Type and stores it in the matching Default* field of entry. It relies
// on entry.Normalize and entry.ListSeparator already having their fallbacks.
func normalizeSetupFieldDefault(entry *normalizedSetupField, raw string) error {
	switch entry.Type {
	case SetupFieldString, SetupFieldURL, SetupFieldSecret:
		entry.DefaultString = entry.Normalize(raw)
		if entry.Type != SetupFieldURL || entry.DefaultString == "" {
			return nil
		}
		if err := ValidateBaseURL(entry.DefaultString); err != nil {
			return fmt.Errorf("%w: field %q default URL is invalid: %v", ErrInvalidSetupDefinition, entry.Name, err)
		}
	case SetupFieldBool:
		trimmed := strings.TrimSpace(raw)
		if trimmed == "" {
			return nil
		}
		parsed, err := parseBoolValue(trimmed)
		if err != nil {
			return fmt.Errorf("%w: field %q default bool is invalid: %v", ErrInvalidSetupDefinition, entry.Name, err)
		}
		entry.DefaultBool = &parsed
	case SetupFieldList:
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			entry.DefaultList = splitSetupList(trimmed, entry.ListSeparator, entry.Normalize)
		}
	}
	return nil
}
// promptSetupField prompts for a single field until it receives an
// acceptable answer and returns the parsed value.
//
// Read failures are wrapped with the field name. For an invalid answer on a
// terminal (isTTY), it prints a hint and asks again. On non-terminal input
// it returns a *SetupValidationError immediately, since there is no user to
// re-prompt. stdin is currently unused: reads go through reader, or through
// fd for secrets.
func promptSetupField(
	reader *bufio.Reader,
	stdin *os.File,
	stdout io.Writer,
	fd int,
	isTTY bool,
	field normalizedSetupField,
) (SetupValue, error) {
	for {
		if err := renderSetupPrompt(stdout, field); err != nil {
			return SetupValue{}, err
		}
		raw, err := readSetupInput(reader, stdout, fd, isTTY, field.Type)
		if err != nil {
			return SetupValue{}, fmt.Errorf("read %q: %w", field.Name, err)
		}

		value, parseErr := parseSetupValue(field, raw)
		if parseErr == nil {
			return value, nil
		}
		if !isTTY {
			return SetupValue{}, &SetupValidationError{
				Field:   field.Name,
				Label:   field.Label,
				Message: parseErr.Error(),
			}
		}
		if _, err := fmt.Fprintf(stdout, "Invalid value for %s: %s\n", field.Label, parseErr.Error()); err != nil {
			return SetupValue{}, err
		}
	}
}
// readSetupInput reads one raw answer.
//
// Secret fields on a terminal are read with term.ReadPassword on fd, which
// does not echo input. Everything else is read as one line from reader; the
// line terminator (LF or CRLF) is removed so that custom Normalize functions
// never see it. Reaching EOF is not an error: whatever was read (possibly "")
// is returned, which lets piped input fall back to defaults.
func readSetupInput(
	reader *bufio.Reader,
	stdout io.Writer,
	fd int,
	isTTY bool,
	fieldType SetupFieldType,
) (string, error) {
	if fieldType == SetupFieldSecret && isTTY {
		secret, err := term.ReadPassword(fd)
		// With echo disabled the user's Enter key never reaches the screen,
		// so move to a fresh line ourselves. Best-effort: a failed write must
		// not mask the outcome of the read.
		_, _ = fmt.Fprintln(stdout)
		if err != nil {
			return "", err
		}
		return string(secret), nil
	}
	line, err := reader.ReadString('\n')
	if err != nil && !errors.Is(err, io.EOF) {
		return "", err
	}
	// ReadString keeps the delimiter. Without stripping it, a blank line
	// would reach Normalize as "\n" and count as a set value whenever
	// Normalize is not whitespace-trimming.
	line = strings.TrimSuffix(line, "\n")
	line = strings.TrimSuffix(line, "\r")
	return line, nil
}
// parseSetupValue dispatches raw input to the parser for field.Type. URL
// fields are parsed as strings and, when a value is set, must also pass
// ValidateBaseURL.
func parseSetupValue(field normalizedSetupField, raw string) (SetupValue, error) {
	switch field.Type {
	case SetupFieldString:
		return parseSetupStringValue(field, raw)
	case SetupFieldSecret:
		return parseSetupSecretValue(field, raw)
	case SetupFieldBool:
		return parseSetupBoolValue(field, raw)
	case SetupFieldList:
		return parseSetupListValue(field, raw)
	case SetupFieldURL:
		value, err := parseSetupStringValue(field, raw)
		if err != nil {
			return SetupValue{}, err
		}
		// An unset optional URL is accepted; only a present value must parse.
		if value.Set && ValidateBaseURL(value.String) != nil {
			return SetupValue{}, errors.New("must be a valid URL with scheme and host")
		}
		return value, nil
	}
	return SetupValue{}, fmt.Errorf("unsupported field type %q", field.Type)
}
// parseSetupStringValue normalizes raw and falls back to DefaultString when
// the result is empty. It rejects a missing value for required fields and
// runs field.Validate only when a value is present.
func parseSetupStringValue(field normalizedSetupField, raw string) (SetupValue, error) {
	text := field.Normalize(raw)
	if text == "" {
		text = field.DefaultString
	}
	present := text != ""
	switch {
	case !present && field.Required:
		return SetupValue{}, fmt.Errorf("value is required")
	case present && field.Validate != nil:
		if err := field.Validate(text); err != nil {
			return SetupValue{}, err
		}
	}
	return SetupValue{Type: field.Type, String: text, Set: present}, nil
}
// parseSetupSecretValue normalizes raw. On blank input it prefers the stored
// ExistingSecret (flagging KeptStoredSecret) and otherwise uses
// DefaultString. It rejects a missing value for required fields and runs
// field.Validate only when a value is present.
func parseSetupSecretValue(field normalizedSetupField, raw string) (SetupValue, error) {
	result := SetupValue{Type: field.Type}
	result.String = field.Normalize(raw)
	if result.String == "" {
		if field.ExistingSecret != "" {
			result.String = field.ExistingSecret
			result.KeptStoredSecret = true
		} else {
			result.String = field.DefaultString
		}
	}
	result.Set = result.String != ""

	if field.Required && !result.Set {
		return SetupValue{}, fmt.Errorf("value is required")
	}
	if result.Set && field.Validate != nil {
		if err := field.Validate(result.String); err != nil {
			return SetupValue{}, err
		}
	}
	return result, nil
}
// parseSetupBoolValue parses a yes/no answer.
//
// Blank input uses DefaultBool when present. Otherwise it is an error for
// required fields and an unset value for optional ones. Any obtained value,
// whether typed or default, is passed through ValidateBool.
func parseSetupBoolValue(field normalizedSetupField, raw string) (SetupValue, error) {
	var choice bool
	switch trimmed := strings.TrimSpace(raw); {
	case trimmed != "":
		parsed, err := parseBoolValue(trimmed)
		if err != nil {
			return SetupValue{}, err
		}
		choice = parsed
	case field.DefaultBool != nil:
		choice = *field.DefaultBool
	case field.Required:
		return SetupValue{}, fmt.Errorf("value is required")
	default:
		return SetupValue{Type: SetupFieldBool, Set: false}, nil
	}

	if field.ValidateBool != nil {
		if err := field.ValidateBool(choice); err != nil {
			return SetupValue{}, err
		}
	}
	return SetupValue{Type: SetupFieldBool, Bool: choice, Set: true}, nil
}
// parseSetupListValue splits raw on field.ListSeparator, normalizing each
// entry and dropping empty ones.
//
// An answer that yields no entries (blank, or only separators such as ",,")
// is treated like blank input: DefaultList is used if present, otherwise the
// value is unset. This mirrors string fields, which fall back to their
// default when the normalized text is empty. Required fields reject an unset
// value. ValidateList, like Validate and ValidateBool, runs only when a value
// is set.
func parseSetupListValue(field normalizedSetupField, raw string) (SetupValue, error) {
	var list []string
	if trimmed := strings.TrimSpace(raw); trimmed != "" {
		list = splitSetupList(trimmed, field.ListSeparator, field.Normalize)
	}
	if len(list) == 0 && len(field.DefaultList) > 0 {
		// Clone so callers cannot mutate the shared default through the result.
		list = slices.Clone(field.DefaultList)
	}
	set := len(list) > 0
	if !set {
		// Report "no value" uniformly as a nil list, whether the answer was
		// blank or separators only.
		list = nil
	}

	if field.Required && !set {
		return SetupValue{}, fmt.Errorf("value is required")
	}
	if set && field.ValidateList != nil {
		if err := field.ValidateList(list); err != nil {
			return SetupValue{}, err
		}
	}
	return SetupValue{
		Type: SetupFieldList,
		List: list,
		Set:  set,
	}, nil
}
// renderSetupPrompt writes the prompt for field to w, without a trailing
// newline, as either "Label: " or "Label [hint]: ".
//
// The hint shows the default value, or the y/n choices (capitalized default)
// for bool fields. For secret fields the actual value is never written,
// whether it is stored or a default. The prompt only says that one exists,
// so secrets do not end up on the terminal or in captured output.
func renderSetupPrompt(w io.Writer, field normalizedSetupField) error {
	var hint string
	switch field.Type {
	case SetupFieldSecret:
		switch {
		case field.ExistingSecret != "":
			hint = "stored, leave blank to keep"
		case field.DefaultString != "":
			// Do not print the default secret itself; only announce it.
			hint = "default set, leave blank to use"
		}
	case SetupFieldBool:
		hint = "y/n"
		if field.DefaultBool != nil {
			if *field.DefaultBool {
				hint = "Y/n"
			} else {
				hint = "y/N"
			}
		}
	case SetupFieldList:
		hint = strings.Join(field.DefaultList, field.ListSeparator)
	case SetupFieldString, SetupFieldURL:
		hint = field.DefaultString
	}

	if hint == "" {
		_, err := fmt.Fprintf(w, "%s: ", field.Label)
		return err
	}
	_, err := fmt.Fprintf(w, "%s [%s]: ", field.Label, hint)
	return err
}
// splitSetupList splits raw on separator, passes each piece through
// normalize, and keeps only the non-empty results in their original order.
// The returned slice is non-nil even when no entries survive.
func splitSetupList(raw, separator string, normalize func(string) string) []string {
	pieces := strings.Split(raw, separator)
	kept := make([]string, 0, len(pieces))
	for i := range pieces {
		if item := normalize(pieces[i]); item != "" {
			kept = append(kept, item)
		}
	}
	return kept
}
// parseBoolValue interprets raw, case-insensitively and ignoring surrounding
// whitespace, as a boolean. Accepted true forms are 1, t, true, y, yes, on;
// false forms are 0, f, false, n, no, off. Anything else is an error.
func parseBoolValue(raw string) (bool, error) {
	token := strings.ToLower(strings.TrimSpace(raw))
	switch token {
	case "1", "t", "true", "y", "yes", "on":
		return true, nil
	case "0", "f", "false", "n", "no", "off":
		return false, nil
	}
	return false, fmt.Errorf("must be one of: yes/no, y/n, true/false, 1/0")
}
// isKnownSetupFieldType reports whether fieldType is one of the SetupField*
// constants supported by RunSetup.
func isKnownSetupFieldType(fieldType SetupFieldType) bool {
	supported := []SetupFieldType{
		SetupFieldString,
		SetupFieldURL,
		SetupFieldSecret,
		SetupFieldBool,
		SetupFieldList,
	}
	return slices.Contains(supported, fieldType)
}