diff --git a/pkg/controller/v1alpha2/experiment/manifest/generator.go b/pkg/controller/v1alpha2/experiment/manifest/generator.go index 085ad752deb..70713041692 100644 --- a/pkg/controller/v1alpha2/experiment/manifest/generator.go +++ b/pkg/controller/v1alpha2/experiment/manifest/generator.go @@ -8,12 +8,12 @@ import ( "text/template" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" experimentsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/experiment/v1alpha2" apiv1alpha2 "github.com/kubeflow/katib/pkg/api/v1alpha2" commonv1alpha2 "github.com/kubeflow/katib/pkg/common/v1alpha2" "github.com/kubeflow/katib/pkg/util/v1alpha2/katibclient" - "sigs.k8s.io/controller-runtime/pkg/client" ) const ( @@ -22,6 +22,7 @@ const ( // Generator is the type for manifests Generator. type Generator interface { + InjectClient(c client.Client) GetRunSpec(e *experimentsv1alpha2.Experiment, experiment, trial, namespace string) (string, error) GetRunSpecWithHyperParameters(e *experimentsv1alpha2.Experiment, experiment, trial, namespace string, hps []*apiv1alpha2.ParameterAssignment) (string, error) GetMetricsCollectorManifest(experimentName string, trialName string, jobKind string, namespace string, metricNames []string, mcs *experimentsv1alpha2.MetricsCollectorSpec) (*bytes.Buffer, error) @@ -40,6 +41,10 @@ func New(c client.Client) Generator { } } +func (g *DefaultGenerator) InjectClient(c client.Client) { + g.client.InjectClient(c) +} + func (g *DefaultGenerator) GetMetricsCollectorManifest(experimentName string, trialName string, jobKind string, namespace string, metricNames []string, mcs *experimentsv1alpha2.MetricsCollectorSpec) (*bytes.Buffer, error) { var mtp *template.Template = nil var err error diff --git a/pkg/controller/v1alpha2/experiment/validation_webhook.go b/pkg/controller/v1alpha2/experiment/validation_webhook.go index 9705a5639ba..fa85536ce76 100644 --- a/pkg/controller/v1alpha2/experiment/validation_webhook.go +++ b/pkg/controller/v1alpha2/experiment/validation_webhook.go @@ -67,6 +67,7 @@ var _ inject.Client = &experimentValidator{} // InjectClient injects the client. func (v *experimentValidator) InjectClient(c client.Client) error { v.client = c + v.Validator.InjectClient(c) return nil } diff --git a/pkg/controller/v1alpha2/experiment/validator/validator.go b/pkg/controller/v1alpha2/experiment/validator/validator.go index faf51af3156..128749c54f9 100644 --- a/pkg/controller/v1alpha2/experiment/validator/validator.go +++ b/pkg/controller/v1alpha2/experiment/validator/validator.go @@ -9,6 +9,7 @@ import ( "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" k8syaml "k8s.io/apimachinery/pkg/util/yaml" logf "sigs.k8s.io/controller-runtime/pkg/runtime/log" + "sigs.k8s.io/controller-runtime/pkg/client" commonapiv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/common/v1alpha2" experimentsv1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/experiment/v1alpha2" @@ -21,6 +22,7 @@ var log = logf.Log.WithName("experiment-controller") type Validator interface { ValidateExperiment(instance *experimentsv1alpha2.Experiment) error + InjectClient(c client.Client) } type DefaultValidator struct { @@ -35,6 +37,10 @@ func New(generator manifest.Generator, managerClient managerclient.ManagerClient } } +func (g *DefaultValidator) InjectClient(c client.Client) { + g.Generator.InjectClient(c) +} + func (g *DefaultValidator) ValidateExperiment(instance *experimentsv1alpha2.Experiment) error { if !instance.IsCreated() { if err := g.validateForCreate(instance); err != nil { diff --git a/pkg/mock/v1alpha2/experiment/manifest/producer.go b/pkg/mock/v1alpha2/experiment/manifest/producer.go index fa3ab2ac160..51e7aef56dd 100644 --- a/pkg/mock/v1alpha2/experiment/manifest/producer.go +++ b/pkg/mock/v1alpha2/experiment/manifest/producer.go @@ -10,6 +10,7 @@ import ( v1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/experiment/v1alpha2" v1alpha20 "github.com/kubeflow/katib/pkg/api/v1alpha2" reflect "reflect" + client "sigs.k8s.io/controller-runtime/pkg/client" ) // MockGenerator is a mock of Generator interface @@ -79,3 +80,15 @@ func (mr *MockGeneratorMockRecorder) GetRunSpecWithHyperParameters(arg0, arg1, a mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRunSpecWithHyperParameters", reflect.TypeOf((*MockGenerator)(nil).GetRunSpecWithHyperParameters), arg0, arg1, arg2, arg3, arg4) } + +// InjectClient mocks base method +func (m *MockGenerator) InjectClient(arg0 client.Client) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "InjectClient", arg0) +} + +// InjectClient indicates an expected call of InjectClient +func (mr *MockGeneratorMockRecorder) InjectClient(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InjectClient", reflect.TypeOf((*MockGenerator)(nil).InjectClient), arg0) +} diff --git a/pkg/mock/v1alpha2/util/katibclient/katibclient.go b/pkg/mock/v1alpha2/util/katibclient/katibclient.go index 89ebb3f9941..4145be12f5c 100644 --- a/pkg/mock/v1alpha2/util/katibclient/katibclient.go +++ b/pkg/mock/v1alpha2/util/katibclient/katibclient.go @@ -8,6 +8,7 @@ import ( gomock "github.com/golang/mock/gomock" v1alpha2 "github.com/kubeflow/katib/pkg/api/operators/apis/experiment/v1alpha2" reflect "reflect" + client "sigs.k8s.io/controller-runtime/pkg/client" ) // MockClient is a mock of Client interface @@ -129,6 +130,18 @@ func (mr *MockClientMockRecorder) GetTrialTemplates(arg0 ...interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTrialTemplates", reflect.TypeOf((*MockClient)(nil).GetTrialTemplates), arg0...) } +// InjectClient mocks base method +func (m *MockClient) InjectClient(arg0 client.Client) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "InjectClient", arg0) +} + +// InjectClient indicates an expected call of InjectClient +func (mr *MockClientMockRecorder) InjectClient(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InjectClient", reflect.TypeOf((*MockClient)(nil).InjectClient), arg0) +} + // UpdateMetricsCollectorTemplates mocks base method func (m *MockClient) UpdateMetricsCollectorTemplates(arg0 map[string]string, arg1 ...string) error { m.ctrl.T.Helper() diff --git a/pkg/util/v1alpha2/katibclient/katib_client.go b/pkg/util/v1alpha2/katibclient/katib_client.go index e31ca254047..c5bdc253d59 100644 --- a/pkg/util/v1alpha2/katibclient/katib_client.go +++ b/pkg/util/v1alpha2/katibclient/katib_client.go @@ -28,6 +28,7 @@ import ( ) type Client interface { + InjectClient(c client.Client) GetExperimentList(namespace ...string) (*experimentsv1alpha2.ExperimentList, error) CreateExperiment(experiment *experimentsv1alpha2.Experiment, namespace ...string) error GetConfigMap(name string, namespace ...string) (map[string]string, error) @@ -58,6 +59,10 @@ func NewClient(options client.Options) (Client, error) { }, nil } +func (k *KatibClient) InjectClient(c client.Client) { + k.client = c +} + func (k *KatibClient) GetExperimentList(namespace ...string) (*experimentsv1alpha2.ExperimentList, error) { ns := getNamespace(namespace...) expList := &experimentsv1alpha2.ExperimentList{}