diff --git a/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java b/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java index f6d416a466e..6db902618a3 100644 --- a/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java +++ b/core/src/main/java/google/registry/beam/spec11/Spec11Pipeline.java @@ -26,6 +26,7 @@ import google.registry.beam.common.RegistryJpaIO.Read; import google.registry.beam.spec11.SafeBrowsingTransforms.EvaluateSafeBrowsingFn; import google.registry.config.RegistryConfig.ConfigModule; +import google.registry.model.IdService; import google.registry.model.domain.Domain; import google.registry.model.reporting.Spec11ThreatMatch; import google.registry.model.reporting.Spec11ThreatMatch.ThreatType; @@ -45,6 +46,7 @@ import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Reshuffle; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; @@ -154,25 +156,36 @@ private static KV parseRow(Object[] row) { static void saveToSql( PCollection> threatMatches, Spec11PipelineOptions options) { - String transformId = "Spec11 Threat Matches"; LocalDate date = LocalDate.parse(options.getDate(), ISODateTimeFormat.date()); - threatMatches.apply( - "Write to Sql: " + transformId, - RegistryJpaIO.>write() - .withName(transformId) - .withBatchSize(options.getSqlWriteBatchSize()) - .withJpaConverter( - (kv) -> { - DomainNameInfo domainNameInfo = kv.getKey(); - return new Spec11ThreatMatch.Builder() - .setThreatTypes( - ImmutableSet.of(ThreatType.valueOf(kv.getValue().threatType()))) - .setCheckDate(date) - .setDomainName(domainNameInfo.domainName()) - .setDomainRepoId(domainNameInfo.domainRepoId()) - .setRegistrarId(domainNameInfo.registrarId()) - .build(); - })); + String transformId = "Spec11 Threat Matches"; + threatMatches + .apply( + "Construct objects", + ParDo.of( + new DoFn, Spec11ThreatMatch>() { + @ProcessElement + public void processElement( + @Element KV input, + OutputReceiver output) { + Spec11ThreatMatch spec11ThreatMatch = + new Spec11ThreatMatch.Builder() + .setThreatTypes( + ImmutableSet.of(ThreatType.valueOf(input.getValue().threatType()))) + .setCheckDate(date) + .setDomainName(input.getKey().domainName()) + .setDomainRepoId(input.getKey().domainRepoId()) + .setRegistrarId(input.getKey().registrarId()) + .setId(IdService.allocateId()) + .build(); + output.output(spec11ThreatMatch); + } + })) + .apply("Prevent Fusing", Reshuffle.viaRandomKey()) + .apply( + "Write to Sql: " + transformId, + RegistryJpaIO.write() + .withName(transformId) + .withBatchSize(options.getSqlWriteBatchSize())); } static void saveToGcs( diff --git a/core/src/main/java/google/registry/model/reporting/Spec11ThreatMatch.java b/core/src/main/java/google/registry/model/reporting/Spec11ThreatMatch.java index 5232df8cc7c..c02bddfe26d 100644 --- a/core/src/main/java/google/registry/model/reporting/Spec11ThreatMatch.java +++ b/core/src/main/java/google/registry/model/reporting/Spec11ThreatMatch.java @@ -22,6 +22,7 @@ import google.registry.model.Buildable; import google.registry.model.ImmutableObject; import google.registry.util.DomainNameUtils; +import java.io.Serializable; import java.util.Set; import javax.persistence.Column; import javax.persistence.Entity; @@ -39,7 +40,7 @@ @Index(name = "spec11threatmatch_tld_idx", columnList = "tld"), @Index(name = "spec11threatmatch_check_date_idx", columnList = "checkDate") }) -public class Spec11ThreatMatch extends ImmutableObject implements Buildable { +public class Spec11ThreatMatch extends ImmutableObject implements Buildable, Serializable { /** The type of threat detected. */ public enum ThreatType { diff --git a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java index 0a11816ad75..ddd5b774e5c 100644 --- a/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java +++ b/core/src/test/java/google/registry/beam/spec11/Spec11PipelineTest.java @@ -280,9 +280,9 @@ private void verifySaveToGcs() throws Exception { private void verifySaveToCloudSql() { tm().transact( () -> { - ImmutableList sqlThreatMatches = + ImmutableList spec11ThreatMatches = Spec11ThreatMatchDao.loadEntriesByDate(tm(), new LocalDate(2020, 1, 27)); - assertThat(sqlThreatMatches) + assertThat(spec11ThreatMatches) .comparingElementsUsing(immutableObjectCorrespondence("id")) .containsExactlyElementsIn(sqlThreatMatches); });