Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions internal/cli/serverless/migration/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func CreateCmd(h *internal.Helper) *cobra.Command {
}
definitionStr := string(definitionBytes)

sources, target, mode, err := parseMigrationDefinition(definitionStr)
sources, target, mode, importMode, err := parseMigrationDefinition(definitionStr)
if err != nil {
return err
}
Expand All @@ -97,6 +97,7 @@ func CreateCmd(h *internal.Helper) *cobra.Command {
Sources: sources,
Target: target,
Mode: mode,
ImportMode: importMode,
}
return runMigrationPrecheck(ctx, d, clusterID, precheckBody, h)
}
Expand All @@ -106,6 +107,7 @@ func CreateCmd(h *internal.Helper) *cobra.Command {
Sources: sources,
Target: target,
Mode: mode,
ImportMode: importMode,
}

resp, err := d.CreateMigration(ctx, clusterID, createBody)
Expand Down Expand Up @@ -247,34 +249,44 @@ func shouldPrintPrecheckItem(status *pkgmigration.PrecheckItemStatus) bool {
}
}

func parseMigrationDefinition(value string) ([]pkgmigration.Source, pkgmigration.Target, pkgmigration.TaskMode, error) {
func parseMigrationDefinition(value string) ([]pkgmigration.Source, pkgmigration.Target, pkgmigration.TaskMode, *pkgmigration.ImportMode, error) {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return nil, pkgmigration.Target{}, "", errors.New("migration config is required; use --config-file")
return nil, pkgmigration.Target{}, "", nil, errors.New("migration config is required; use --config-file")
}
var payload struct {
Sources []pkgmigration.Source `json:"sources"`
Target *pkgmigration.Target `json:"target"`
Mode string `json:"mode"`
Sources []pkgmigration.Source `json:"sources"`
Target *pkgmigration.Target `json:"target"`
Mode string `json:"mode"`
ImportMode *string `json:"importMode"`
}
stdJson, err := standardizeJSON([]byte(trimmed))
if err != nil {
return nil, pkgmigration.Target{}, "", errors.Annotate(err, "invalid migration definition JSON")
return nil, pkgmigration.Target{}, "", nil, errors.Annotate(err, "invalid migration definition JSON")
}
if err := json.Unmarshal(stdJson, &payload); err != nil {
return nil, pkgmigration.Target{}, "", errors.Annotate(err, "invalid migration definition JSON")
return nil, pkgmigration.Target{}, "", nil, errors.Annotate(err, "invalid migration definition JSON")
}
if len(payload.Sources) == 0 {
return nil, pkgmigration.Target{}, "", errors.New("migration definition must include at least one source")
return nil, pkgmigration.Target{}, "", nil, errors.New("migration definition must include at least one source")
}
if payload.Target == nil {
return nil, pkgmigration.Target{}, "", errors.New("migration definition must include the target block")
return nil, pkgmigration.Target{}, "", nil, errors.New("migration definition must include the target block")
}
mode, err := parseMigrationMode(payload.Mode)
if err != nil {
return nil, pkgmigration.Target{}, "", err
return nil, pkgmigration.Target{}, "", nil, err
}
return payload.Sources, *payload.Target, mode, nil

importMode, err := parseImportMode(payload.ImportMode)
if err != nil {
return nil, pkgmigration.Target{}, "", nil, err
}
if mode == pkgmigration.TASKMODE_INCREMENTAL && importMode != nil {
return nil, pkgmigration.Target{}, "", nil, errors.New("importMode is only applicable for mode=ALL; remove importMode or switch to mode=ALL")
}

return payload.Sources, *payload.Target, mode, importMode, nil
}

func parseMigrationMode(value string) (pkgmigration.TaskMode, error) {
Expand All @@ -290,6 +302,30 @@ func parseMigrationMode(value string) (pkgmigration.TaskMode, error) {
return "", errors.Errorf("invalid mode %q, allowed values: %s", value, pkgmigration.AllowedTaskModeEnumValues)
}

func parseImportMode(raw *string) (*pkgmigration.ImportMode, error) {
if raw == nil {
return nil, nil
}
trimmed := strings.TrimSpace(*raw)
if trimmed == "" {
return nil, nil
}

normalized := strings.ToUpper(trimmed)
switch normalized {
case "LOGICAL":
normalized = "IMPORT_MODE_LOGICAL"
case "PHYSICAL":
normalized = "IMPORT_MODE_PHYSICAL"
}

mode := pkgmigration.ImportMode(normalized)
if slices.Contains(pkgmigration.AllowedImportModeEnumValues, mode) {
return &mode, nil
}
return nil, errors.Errorf("invalid importMode %q, allowed values: %s", trimmed, pkgmigration.AllowedImportModeEnumValues)
}

// standardizeJSON accepts JSON With Commas and Comments(JWCC) see
// https://nigeltao.github.io/blog/2021/json-with-commas-comments.html) and
// returns a standard JSON byte slice ready for json.Unmarshal.
Expand Down
17 changes: 16 additions & 1 deletion internal/cli/serverless/migration/create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,13 @@ func (suite *CreateMigrationSuite) TestCreateMigration() {
ctx,
clusterID,
mockTool.MatchedBy(func(body *pkgmigration.MigrationServiceCreateMigrationBody) bool {
hasImportMode := body != nil && body.ImportMode != nil && *body.ImportMode == pkgmigration.IMPORTMODE_IMPORT_MODE_LOGICAL
return body != nil &&
body.DisplayName == displayName &&
body.Mode == pkgmigration.TASKMODE_ALL &&
len(body.Sources) == 1 &&
body.Target.User == "migration_user"
body.Target.User == "migration_user" &&
hasImportMode
}),
).Return(&pkgmigration.Migration{MigrationId: aws.String(migrationID)}, nil)

Expand All @@ -98,6 +100,8 @@ func (suite *CreateMigrationSuite) TestCreateMigrationInvalidInputs() {
blankPath := suite.writeTempConfig(" ")
invalidJSONPath := suite.writeTempConfig("{invalid")
invalidModePath := suite.writeTempConfig(`{ "mode": "invalid", "target": {"user":"u","password":"p"}, "sources": [{"sourceType":"MYSQL","connProfile":{"connType":"PUBLIC","host":"h","port":3306,"user":"u","password":"p"}}] }`)
invalidImportModePath := suite.writeTempConfig(`{ "mode": "ALL", "importMode": "nope", "target": {"user":"u","password":"p"}, "sources": [{"sourceType":"MYSQL","connProfile":{"connType":"PUBLIC","host":"h","port":3306,"user":"u","password":"p"}}] }`)
importModeWithIncrementalPath := suite.writeTempConfig(`{ "mode": "INCREMENTAL", "importMode": "logical", "target": {"user":"u","password":"p"}, "sources": [{"sourceType":"MYSQL","connProfile":{"connType":"PUBLIC","host":"h","port":3306,"user":"u","password":"p"}}] }`)

tests := []struct {
name string
Expand All @@ -124,6 +128,16 @@ func (suite *CreateMigrationSuite) TestCreateMigrationInvalidInputs() {
args: []string{"--cluster-id", "c1", "--display-name", "name", "--config-file", invalidModePath},
errContains: "invalid mode",
},
{
name: "invalid importMode",
args: []string{"--cluster-id", "c1", "--display-name", "name", "--config-file", invalidImportModePath},
errContains: "invalid importMode",
},
{
name: "importMode with incremental mode",
args: []string{"--cluster-id", "c1", "--display-name", "name", "--config-file", importModeWithIncrementalPath},
errContains: "importMode is only applicable for mode=ALL",
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -152,6 +166,7 @@ func (suite *CreateMigrationSuite) writeTempConfig(content string) string {
func validMigrationConfig() string {
return `{
"mode": "ALL",
"importMode": "IMPORT_MODE_LOGICAL",
"target": {
"user": "migration_user",
"password": "Passw0rd!"
Expand Down
4 changes: 4 additions & 0 deletions internal/cli/serverless/migration/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ const (
migrationDefinitionAllTemplate = `{
// Required migration mode. Use "ALL" for full + incremental.
"mode": "ALL",
// Optional import mode for full migration phase.
// Supported values: IMPORT_MODE_LOGICAL, IMPORT_MODE_PHYSICAL
// Note: Not applicable for mode = INCREMENTAL.
"importMode": "IMPORT_MODE_LOGICAL",
// Target TiDB Cloud user credentials used by the migration
"target": {
"user": "migration_user",
Expand Down
Loading
Loading