diff --git a/internal/config/config.go b/internal/config/config.go index 1b24695c..f52a8084 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -26,8 +26,13 @@ func SetFileName(name string) { filename = name + // Validate file extension before creating the file + err := validateConfigFileExtension(filename) + checkErr(err) + file.CreateEmptyIfNotExists(filename) - configureViper(filename) + err = configureViper(filename) + checkErr(err) } func SetFileNameForTest(t *testing.T) { diff --git a/internal/config/viper.go b/internal/config/viper.go index 9409fd8d..5f19a24c 100644 --- a/internal/config/viper.go +++ b/internal/config/viper.go @@ -5,6 +5,10 @@ package config import ( "bytes" + "path/filepath" + "strings" + + "github.com/microsoft/go-sqlcmd/internal/localizer" "github.com/microsoft/go-sqlcmd/internal/pal" "github.com/spf13/viper" "gopkg.in/yaml.v2" @@ -56,16 +60,44 @@ func GetConfigFileUsed() string { return viper.ConfigFileUsed() } +// validateConfigFileExtension checks if the config file has a supported extension. +// It allows .yaml, .yml, and no extension (for default sqlconfig file). +// Returns an error if the extension is not supported. +func validateConfigFileExtension(configFile string) error { + ext := strings.ToLower(filepath.Ext(configFile)) + + // Allow no extension (for default sqlconfig file) + if ext == "" { + return nil + } + + // Allow .yaml and .yml extensions + if ext == ".yaml" || ext == ".yml" { + return nil + } + + // Return error for unsupported extensions + return localizer.Errorf( + "Configuration files must use YAML format with .yaml or .yml extension. The file '%s' has an unsupported extension '%s'.", + configFile, ext) +} + // configureViper initializes the Viper library with the given configuration file. // This function sets the configuration file type to "yaml" and sets the environment variable prefix to "SQLCMD". // It also sets the configuration file to use to the one provided as an argument to the function. // This function is intended to be called at the start of the application to configure Viper before any other code uses it. -func configureViper(configFile string) { +func configureViper(configFile string) error { if configFile == "" { panic("Must provide configFile") } + // Validate file extension + if err := validateConfigFileExtension(configFile); err != nil { + return err + } + viper.SetConfigType("yaml") viper.SetEnvPrefix("SQLCMD") viper.SetConfigFile(configFile) + return nil } diff --git a/internal/config/viper_test.go b/internal/config/viper_test.go index 3a606f6d..f192d018 100644 --- a/internal/config/viper_test.go +++ b/internal/config/viper_test.go @@ -14,6 +14,145 @@ func Test_configureViper(t *testing.T) { }) } +func Test_validateConfigFileExtension(t *testing.T) { + tests := []struct { + name string + filename string + wantErr bool + }{ + { + name: "valid yaml extension", + filename: "config.yaml", + wantErr: false, + }, + { + name: "valid yml extension", + filename: "config.yml", + wantErr: false, + }, + { + name: "no extension (default sqlconfig)", + filename: "sqlconfig", + wantErr: false, + }, + { + name: "no extension with path", + filename: "/home/user/.sqlcmd/sqlconfig", + wantErr: false, + }, + { + name: "invalid txt extension", + filename: "config.txt", + wantErr: true, + }, + { + name: "invalid json extension", + filename: "config.json", + wantErr: true, + }, + { + name: "invalid xml extension", + filename: "config.xml", + wantErr: true, + }, + { + name: "uppercase YAML extension", + filename: "config.YAML", + wantErr: false, + }, + { + name: "uppercase YML extension", + filename: "config.YML", + wantErr: false, + }, + { + name: "mixed case yaml extension", + filename: "config.Yaml", + wantErr: false, + }, + { + name: "multiple dots with valid extension", + filename: "my.config.yaml", + wantErr: false, + }, + { + name: "multiple dots with invalid extension", + filename: "my.config.txt", + wantErr: true, + }, + { + name: "backup file with valid extension", + filename: "config.backup.yaml", + wantErr: false, + }, + { + name: "backup file with invalid extension", + filename: "config.backup.txt", + wantErr: true, + }, + { + name: "hidden file with yaml extension", + filename: ".config.yaml", + wantErr: false, + }, + { + name: "hidden file with yml extension", + filename: ".config.yml", + wantErr: false, + }, + { + name: "hidden file with invalid extension", + filename: ".config.txt", + wantErr: true, + }, + { + name: "file with only dot and yaml", + filename: ".yaml", + wantErr: false, + }, + { + name: "file with only dot and yml", + filename: ".yml", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateConfigFileExtension(tt.filename) + if tt.wantErr { + assert.Error(t, err, "Expected error for filename: %s", tt.filename) + assert.Contains(t, err.Error(), "Configuration files must use YAML format") + } else { + assert.NoError(t, err, "Expected no error for filename: %s", tt.filename) + } + }) + } +} + +func Test_configureViper_withInvalidExtension(t *testing.T) { + err := configureViper("myconfig.txt") + assert.Error(t, err) + assert.Contains(t, err.Error(), "Configuration files must use YAML format") + assert.Contains(t, err.Error(), ".txt") +} + +func Test_configureViper_withValidExtensions(t *testing.T) { + testCases := []string{ + "config.yaml", + "config.yml", + "sqlconfig", + "/path/to/config.yaml", + } + + for _, filename := range testCases { + t.Run(filename, func(t *testing.T) { + err := configureViper(filename) + assert.NoError(t, err, "Expected no error for filename: %s", filename) + }) + } +} + func Test_Load(t *testing.T) { SetFileNameForTest(t) Clean()