Skip to content

Commit

Permalink
Add use_public_ips to Go Dataflow Runner (#28308)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmccluskey authored Sep 5, 2023
1 parent fc4b459 commit 7c1c250
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 0 deletions.
22 changes: 22 additions & 0 deletions sdks/go/pkg/beam/runners/dataflow/dataflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ var (
network = flag.String("network", "", "GCP network (optional)")
subnetwork = flag.String("subnetwork", "", "GCP subnetwork (optional)")
noUsePublicIPs = flag.Bool("no_use_public_ips", false, "Workers must not use public IP addresses (optional)")
usePublicIPs = flag.Bool("use_public_ips", true, "Workers must use public IP addresses (optional)")
tempLocation = flag.String("temp_location", "", "Temp location (optional)")
workerMachineType = flag.String("worker_machine_type", "", "GCE machine type (optional)")
machineType = flag.String("machine_type", "", "alias of worker_machine_type (optional)")
Expand Down Expand Up @@ -245,6 +246,16 @@ func Execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error)
return dataflowlib.Execute(ctx, model, opts, workerURL, modelURL, *endpoint, *jobopts.Async)
}

// isFlagPassed reports whether the command-line flag called name was
// explicitly set. It relies on flag.Visit, which only walks flags that have
// been set, so a flag that merely holds its default value yields false.
func isFlagPassed(name string) bool {
	var passed bool
	flag.Visit(func(set *flag.Flag) {
		passed = passed || set.Name == name
	})
	return passed
}

func getJobOptions(ctx context.Context, streaming bool) (*dataflowlib.JobOptions, error) {
project := gcpopts.GetProjectFromFlagOrEnvironment(ctx)
if project == "" {
Expand Down Expand Up @@ -294,6 +305,17 @@ func getJobOptions(ctx context.Context, streaming bool) (*dataflowlib.JobOptions
return nil, errors.Wrapf(err, "error reading --transform_name_mapping flag as JSON")
}
}
if *usePublicIPs == *noUsePublicIPs {
useSet := isFlagPassed("use_public_ips")
noUseSet := isFlagPassed("no_use_public_ips")
// If use_public_ips was explicitly set but no_use_public_ips was not, derive
// no_use_public_ips from it. If use_public_ips was not explicitly set, the
// current value of no_use_public_ips is kept as-is. Explicitly setting both
// flags to contradictory values is an error.
if useSet && !noUseSet {
*noUsePublicIPs = !*usePublicIPs
} else if useSet && noUseSet {
return nil, errors.New("exactly one of usePublicIPs and noUsePublicIPs must be true, please check that only one is true")
}
}

hooks.SerializeHooksToOptions()

Expand Down
77 changes: 77 additions & 0 deletions sdks/go/pkg/beam/runners/dataflow/dataflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,81 @@ func TestGetJobOptions_AliasAreEffective(t *testing.T) {
}
}

// TestGetJobOptions_BadTruePublicIPs expects getJobOptions to fail and return
// nil options when use_public_ips and no_use_public_ips are both true.
//
// NOTE(review): the values are written through the flag pointers rather than
// with flag.Set, so flag.Visit (used by isFlagPassed) will not treat either
// flag as explicitly passed. Confirm that the error asserted here really comes
// from the public-IP conflict check and not from an earlier validation step.
func TestGetJobOptions_BadTruePublicIPs(t *testing.T) {
	resetGlobals()
	*noUsePublicIPs, *usePublicIPs = true, true

	jobOpts, gotErr := getJobOptions(context.Background(), false)
	if gotErr == nil {
		t.Error("getJobOptions() returned error nil, want an error")
	}
	if jobOpts != nil {
		t.Errorf("getJobOptions() returned JobOptions when it should not have, got %#v, want nil", jobOpts)
	}
}

// TestGetJobOptions_BadFalsePublicIPs expects getJobOptions to fail and return
// nil options when use_public_ips and no_use_public_ips are both false.
//
// NOTE(review): the values are written through the flag pointers rather than
// with flag.Set, so flag.Visit (used by isFlagPassed) will not treat either
// flag as explicitly passed. Confirm that the error asserted here really comes
// from the public-IP conflict check and not from an earlier validation step.
func TestGetJobOptions_BadFalsePublicIPs(t *testing.T) {
	resetGlobals()
	*noUsePublicIPs, *usePublicIPs = false, false

	jobOpts, gotErr := getJobOptions(context.Background(), false)
	if gotErr == nil {
		t.Error("getJobOptions() returned error nil, want an error")
	}
	if jobOpts != nil {
		t.Errorf("getJobOptions() returned JobOptions when it should not have, got %#v, want nil", jobOpts)
	}
}

// TestGetJobOptions_DefaultPublicIPs verifies that, when neither public-IP
// flag is touched after resetGlobals, the resulting job options have
// NoUsePublicIPs set to false.
func TestGetJobOptions_DefaultPublicIPs(t *testing.T) {
	resetGlobals()

	// GCP and job-level settings.
	*gcpopts.Project = "testProject"
	*gcpopts.Region = "testRegion"
	*jobopts.JobName = "testJob"
	*jobopts.Experiments = "use_runner_v2,use_portable_job_submission"

	// Dataflow runner settings.
	*stagingLocation = "gs://testStagingLocation"
	*labels = `{"label1": "val1", "label2": "val2"}`
	*minCPUPlatform = "testPlatform"
	*flexRSGoal = "FLEXRS_SPEED_OPTIMIZED"
	*dataflowServiceOptions = "opt1,opt2"

	jobOpts, err := getJobOptions(context.Background(), false)
	if err != nil {
		t.Fatalf("getJobOptions() returned error %q, want %q", err, "nil")
	}
	if jobOpts.NoUsePublicIPs {
		t.Errorf("getJobOptions().NoUsePublicIPs = %t, want %t", jobOpts.NoUsePublicIPs, false)
	}
}

// TestGetJobOptions_NoUsePublicIPs verifies that setting no_use_public_ips to
// true (with use_public_ips left at its reset value) yields job options with
// NoUsePublicIPs set to true.
func TestGetJobOptions_NoUsePublicIPs(t *testing.T) {
	resetGlobals()

	// GCP and job-level settings.
	*gcpopts.Project = "testProject"
	*gcpopts.Region = "testRegion"
	*jobopts.JobName = "testJob"
	*jobopts.Experiments = "use_runner_v2,use_portable_job_submission"

	// Dataflow runner settings, including the flag under test.
	*stagingLocation = "gs://testStagingLocation"
	*labels = `{"label1": "val1", "label2": "val2"}`
	*minCPUPlatform = "testPlatform"
	*flexRSGoal = "FLEXRS_SPEED_OPTIMIZED"
	*dataflowServiceOptions = "opt1,opt2"
	*noUsePublicIPs = true

	jobOpts, err := getJobOptions(context.Background(), false)
	if err != nil {
		t.Fatalf("getJobOptions() returned error %q, want %q", err, "nil")
	}
	if !jobOpts.NoUsePublicIPs {
		t.Errorf("getJobOptions().NoUsePublicIPs = %t, want %t", jobOpts.NoUsePublicIPs, true)
	}
}

// getFieldFromOpt looks up the field called fieldName on the struct that opts
// points to, using reflection, and returns the result of calling String on
// that field's reflect.Value.
func getFieldFromOpt(fieldName string, opts *dataflowlib.JobOptions) string {
	structVal := reflect.ValueOf(opts).Elem()
	field := structVal.FieldByName(fieldName)
	return field.String()
}
Expand All @@ -447,6 +522,8 @@ func resetGlobals() {
*stagingLocation = ""
*transformMapping = ""
*update = false
*usePublicIPs = true
*noUsePublicIPs = false
*workerHarnessImage = ""
*workerMachineType = ""
*machineType = ""
Expand Down

0 comments on commit 7c1c250

Please sign in to comment.