diff --git a/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java b/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java
index e478f551e4c..f85b80c2d60 100644
--- a/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java
+++ b/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java
@@ -26,6 +26,7 @@
 import google.registry.beam.common.RegistryJpaIO.Read;
 import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn;
 import google.registry.config.RegistryConfig.ConfigModule;
+import google.registry.model.IdService;
 import google.registry.model.domain.Domain;
 import google.registry.model.reporting.Spec11ThreatMatch;
 import google.registry.model.reporting.Spec11ThreatMatch.ThreatType;
@@ -43,6 +44,7 @@
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Reshuffle;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TypeDescriptor;
@@ -154,26 +156,43 @@ private static KV parseRow(Object[] row) {
 
   static void saveToSql(
       PCollection<KV<DomainNameInfo, ThreatMatch>> threatMatches, Spec11PipelineOptions options) {
-    String transformId = "Spec11 Threat Matches";
     LocalDate date = LocalDate.parse(options.getDate(), ISODateTimeFormat.date());
-    threatMatches.apply(
+
+    PCollection<Spec11ThreatMatch> spec11ThreatMatches =
+        threatMatches.apply(
+            "Construct objects",
+            ParDo.of(
+                new DoFn<KV<DomainNameInfo, ThreatMatch>, Spec11ThreatMatch>() {
+                  @ProcessElement
+                  public void processElement(
+                      @Element KV<DomainNameInfo, ThreatMatch> input,
+                      OutputReceiver<Spec11ThreatMatch> output) {
+                    Spec11ThreatMatch spec11ThreatMatch =
+                        new Spec11ThreatMatch.Builder()
+                            .setThreatTypes(
+                                ImmutableSet.of(ThreatType.valueOf(input.getValue().threatType())))
+                            .setCheckDate(date)
+                            .setDomainName(input.getKey().domainName())
+                            .setDomainRepoId(input.getKey().domainRepoId())
+                            .setRegistrarId(input.getKey().registrarId())
+                            .setId(IdService.allocateId())
+                            .build();
+                    output.output(spec11ThreatMatch);
+                  }
+                }));
+
+    // apply() returns a new PCollection rather than modifying the receiver, so the write below
+    // must consume the reshuffled collection for "Prevent Fusing" to have any effect.
+    PCollection<Spec11ThreatMatch> reshuffledThreatMatches =
+        spec11ThreatMatches.apply("Prevent Fusing", Reshuffle.viaRandomKey());
+    String transformId = "Spec11 Threat Matches";
+
+    reshuffledThreatMatches.apply(
         "Write to Sql: " + transformId,
-        RegistryJpaIO.<KV<DomainNameInfo, ThreatMatch>>write()
+        RegistryJpaIO.<Spec11ThreatMatch>write()
             .withName(transformId)
             .withBatchSize(options.getSqlWriteBatchSize())
-            .withShards(options.getSqlWriteShards())
-            .withJpaConverter(
-                (kv) -> {
-                  DomainNameInfo domainNameInfo = kv.getKey();
-                  return new Spec11ThreatMatch.Builder()
-                      .setThreatTypes(
-                          ImmutableSet.of(ThreatType.valueOf(kv.getValue().threatType())))
-                      .setCheckDate(date)
-                      .setDomainName(domainNameInfo.domainName())
-                      .setDomainRepoId(domainNameInfo.domainRepoId())
-                      .setRegistrarId(domainNameInfo.registrarId())
-                      .build();
-                }));
+            .withShards(options.getSqlWriteShards()));
   }
 
   static void saveToGcs(
diff --git a/core/src/main/java/google/registry/model/reporting/Spec11ThreatMatch.java b/core/src/main/java/google/registry/model/reporting/Spec11ThreatMatch.java
index 5232df8cc7c..c02bddfe26d 100644
--- a/core/src/main/java/google/registry/model/reporting/Spec11ThreatMatch.java
+++ b/core/src/main/java/google/registry/model/reporting/Spec11ThreatMatch.java
@@ -22,6 +22,7 @@
 import google.registry.model.Buildable;
 import google.registry.model.ImmutableObject;
 import google.registry.util.DomainNameUtils;
+import java.io.Serializable;
 import java.util.Set;
 import javax.persistence.Column;
 import javax.persistence.Entity;
@@ -39,7 +40,7 @@
       @Index(name = "spec11threatmatch_tld_idx", columnList = "tld"),
       @Index(name = "spec11threatmatch_check_date_idx", columnList = "checkDate")
     })
-public class Spec11ThreatMatch extends ImmutableObject implements Buildable {
+public class Spec11ThreatMatch extends ImmutableObject implements Buildable, Serializable {
 
   /** The type of threat detected. */
   public enum ThreatType {
diff --git a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java
index c9f56e42911..59f9b4be267 100644
--- a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java
+++ b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java
@@ -291,7 +291,7 @@ private void verifySaveToCloudSql() {
           ImmutableList<Spec11ThreatMatch> sqlThreatMatches =
               Spec11ThreatMatchDao.loadEntriesByDate(jpaTm(), new LocalDate(2020, 1, 27));
           assertThat(sqlThreatMatches)
-              .comparingElementsUsing(immutableObjectCorrespondence("id"))
+              .comparingElementsUsing(immutableObjectCorrespondence())
               .containsExactlyElementsIn(sqlThreatMatches);
         });
   }