-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
docs(samples): Retail. Prediction samples (#412)
* Update README * the bash scripts are added * Update README.md * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * Update user_environment_setup.sh * Update user_import_data_to_catalog.sh * Update README.md * Add predictions samples. * Add predictions samples unit-tests. * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md Co-authored-by: tetiana-karasova <tetiana.karasova@gmail.com> Co-authored-by: t-karasova <tkarasova@google.com> Co-authored-by: t-karasova <91195610+t-karasova@users.noreply.github.com> Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com> Co-authored-by: Karl Weinmeister <11586922+kweinmeister@users.noreply.github.com>
- Loading branch information
1 parent
aac9cac
commit cbc98d7
Showing
6 changed files
with
482 additions
and
0 deletions.
There are no files selected for viewing
89 changes: 89 additions & 0 deletions
89
retail/interactive-tutorials/src/main/java/prediction/FilteringPrediction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
/* | ||
* Copyright 2022 Google LLC | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
/* | ||
* [START retail_prediction_get_prediction_with_filtering] | ||
* Call Retail API to get predictions from Recommendation AI using filtering. | ||
*/ | ||
|
||
package prediction; | ||
|
||
import com.google.cloud.retail.v2.PredictRequest; | ||
import com.google.cloud.retail.v2.PredictResponse; | ||
import com.google.cloud.retail.v2.PredictionServiceClient; | ||
import com.google.cloud.retail.v2.Product; | ||
import com.google.cloud.retail.v2.ProductDetail; | ||
import com.google.cloud.retail.v2.UserEvent; | ||
import com.google.protobuf.Value; | ||
import java.io.IOException; | ||
|
||
public class FilteringPrediction { | ||
|
||
public static void main(String[] args) { | ||
String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); | ||
String placementId = System.getenv("GOOGLE_CLOUD_PLACEMENT"); | ||
String predictPlacement = | ||
String.format( | ||
"projects/%s/locations/global/catalogs/default_catalog/placements/%s", | ||
projectId, placementId); | ||
|
||
predict(predictPlacement); | ||
} | ||
|
||
public static void predict(String predictPlacement) { | ||
try (PredictionServiceClient predictionServiceClient = PredictionServiceClient.create()) { | ||
PredictResponse predictResponse = | ||
predictionServiceClient.predict(getPredictRequest(predictPlacement)); | ||
System.out.printf("Predict response: %n%s", predictResponse); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
} | ||
|
||
private static PredictRequest getPredictRequest(String predictPlacement) { | ||
// create product object | ||
Product product = | ||
Product.newBuilder() | ||
.setId("55106") // Id of real product | ||
.build(); | ||
|
||
// create product detail object | ||
ProductDetail productDetail = ProductDetail.newBuilder().setProduct(product).build(); | ||
|
||
// create user event object | ||
UserEvent userEvent = | ||
UserEvent.newBuilder() | ||
.setEventType("detail-page-view") | ||
.setVisitorId("281790") // Unique identifier to track visitors | ||
.addProductDetails(productDetail) | ||
.build(); | ||
|
||
PredictRequest predictRequest = | ||
PredictRequest.newBuilder() | ||
.setPlacement(predictPlacement) | ||
.setUserEvent(userEvent) | ||
// TRY DIFFERENT FILTER HERE: | ||
.setFilter("filterOutOfStockItems") | ||
// TRY TO UPDATE `strictFiltering` HERE: | ||
.putParams("strictFiltering", Value.newBuilder().setBoolValue(true).build()) | ||
.build(); | ||
System.out.printf("Predict request: %n%s", predictRequest); | ||
|
||
return predictRequest; | ||
} | ||
} | ||
|
||
// [END retail_prediction_get_prediction_with_filtering] |
89 changes: 89 additions & 0 deletions
89
retail/interactive-tutorials/src/main/java/prediction/PredictionWithParameters.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
/* | ||
* Copyright 2022 Google LLC | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
/* | ||
* [START retail_prediction_get_prediction_with_params] | ||
* Call Retail API to get predictions from Recommendation AI using parameters. | ||
*/ | ||
|
||
package prediction; | ||
|
||
import com.google.cloud.retail.v2.PredictRequest; | ||
import com.google.cloud.retail.v2.PredictResponse; | ||
import com.google.cloud.retail.v2.PredictionServiceClient; | ||
import com.google.cloud.retail.v2.Product; | ||
import com.google.cloud.retail.v2.ProductDetail; | ||
import com.google.cloud.retail.v2.UserEvent; | ||
import com.google.protobuf.Value; | ||
import java.io.IOException; | ||
|
||
public class PredictionWithParameters { | ||
|
||
public static void main(String[] args) { | ||
String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); | ||
String placementId = System.getenv("GOOGLE_CLOUD_PLACEMENT"); | ||
String predictPlacement = | ||
String.format( | ||
"projects/%s/locations/global/catalogs/default_catalog/placements/%s", | ||
projectId, placementId); | ||
|
||
predict(predictPlacement); | ||
} | ||
|
||
public static void predict(String predictPlacement) { | ||
try (PredictionServiceClient predictionServiceClient = PredictionServiceClient.create()) { | ||
PredictResponse predictResponse = | ||
predictionServiceClient.predict(getPredictRequest(predictPlacement)); | ||
System.out.printf("Predict response: %n%s", predictResponse); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
} | ||
|
||
private static PredictRequest getPredictRequest(String predictPlacement) { | ||
// create product object | ||
Product product = | ||
Product.newBuilder() | ||
.setId("55106") // Id of real product | ||
.build(); | ||
|
||
// create product detail object | ||
ProductDetail productDetail = ProductDetail.newBuilder().setProduct(product).build(); | ||
|
||
// create user event object | ||
UserEvent userEvent = | ||
UserEvent.newBuilder() | ||
.setEventType("detail-page-view") | ||
.setVisitorId("281790") // Unique identifier to track visitors | ||
.addProductDetails(productDetail) | ||
.build(); | ||
|
||
PredictRequest predictRequest = | ||
PredictRequest.newBuilder() | ||
.setPlacement(predictPlacement) // Placement is used to identify the Serving Config name | ||
.setUserEvent(userEvent) // Context about the user is required for event logging | ||
// TRY TO ADD/UPDATE PARAMETERS `priceRerankLevel` OR `diversityLevel` HERE: | ||
.putParams( | ||
"priceRerankLevel", | ||
Value.newBuilder().setStringValue("low-price-reranking").build()) | ||
.build(); | ||
System.out.printf("Predict request: %n%s", predictRequest); | ||
|
||
return predictRequest; | ||
} | ||
} | ||
|
||
// [END retail_prediction_get_prediction_with_params] |
86 changes: 86 additions & 0 deletions
86
retail/interactive-tutorials/src/main/java/prediction/SimplePrediction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
/* | ||
* Copyright 2022 Google LLC | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
/* | ||
* [START retail_prediction_get_simple_prediction] | ||
* Call Retail API to get predictions from Recommendation AI using simple request. | ||
*/ | ||
|
||
package prediction; | ||
|
||
import com.google.cloud.retail.v2.PredictRequest; | ||
import com.google.cloud.retail.v2.PredictResponse; | ||
import com.google.cloud.retail.v2.PredictionServiceClient; | ||
import com.google.cloud.retail.v2.Product; | ||
import com.google.cloud.retail.v2.ProductDetail; | ||
import com.google.cloud.retail.v2.UserEvent; | ||
import com.google.protobuf.Value; | ||
import java.io.IOException; | ||
|
||
public class SimplePrediction { | ||
|
||
public static void main(String[] args) { | ||
String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); | ||
String placementId = System.getenv("GOOGLE_CLOUD_PLACEMENT"); | ||
String predictPlacement = | ||
String.format( | ||
"projects/%s/locations/global/catalogs/default_catalog/placements/%s", | ||
projectId, placementId); | ||
|
||
predict(predictPlacement); | ||
} | ||
|
||
public static void predict(String predictPlacement) { | ||
try (PredictionServiceClient predictionServiceClient = PredictionServiceClient.create()) { | ||
PredictResponse predictResponse = | ||
predictionServiceClient.predict(getPredictRequest(predictPlacement)); | ||
System.out.printf("Predict response: %n%s", predictResponse); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
} | ||
|
||
private static PredictRequest getPredictRequest(String predictPlacement) { | ||
// create product object | ||
Product product = | ||
Product.newBuilder() | ||
.setId("55106") // Id of real product | ||
.build(); | ||
|
||
// create product detail object | ||
ProductDetail productDetail = ProductDetail.newBuilder().setProduct(product).build(); | ||
|
||
// create user event object | ||
UserEvent userEvent = | ||
UserEvent.newBuilder() | ||
.setEventType("detail-page-view") | ||
.setVisitorId("281790") // Unique identifier to track visitors | ||
.addProductDetails(productDetail) | ||
.build(); | ||
|
||
PredictRequest predictRequest = | ||
PredictRequest.newBuilder() | ||
.setPlacement(predictPlacement) // Placement is used to identify the Serving Config name | ||
.setUserEvent(userEvent) // Context about the user is required for event logging | ||
.putParams("returnProduct", Value.newBuilder().setBoolValue(true).build()) | ||
.build(); | ||
System.out.printf("Predict request: %n%s", predictRequest); | ||
|
||
return predictRequest; | ||
} | ||
} | ||
|
||
// [END retail_prediction_get_simple_prediction] |
68 changes: 68 additions & 0 deletions
68
retail/interactive-tutorials/src/test/java/prediction/FilteringPredictionTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/* | ||
* Copyright 2022 Google LLC | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package prediction; | ||
|
||
import static com.google.common.truth.Truth.assertThat; | ||
import static prediction.FilteringPrediction.predict; | ||
|
||
import java.io.ByteArrayOutputStream; | ||
import java.io.IOException; | ||
import java.io.PrintStream; | ||
import org.junit.After; | ||
import org.junit.Before; | ||
import org.junit.Test; | ||
import org.junit.runner.RunWith; | ||
import org.junit.runners.JUnit4; | ||
|
||
@RunWith(JUnit4.class) | ||
public class FilteringPredictionTest { | ||
|
||
private ByteArrayOutputStream bout; | ||
private PrintStream originalPrintStream; | ||
|
||
@Before | ||
public void setUp() throws IOException, InterruptedException { | ||
bout = new ByteArrayOutputStream(); | ||
PrintStream out = new PrintStream(bout); | ||
originalPrintStream = System.out; | ||
System.setOut(out); | ||
} | ||
|
||
@Test | ||
public void testPredict() { | ||
String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); | ||
String placementId = System.getenv("GOOGLE_CLOUD_PLACEMENT"); | ||
String predictPlacement = | ||
String.format( | ||
"projects/%s/locations/global/catalogs/default_catalog/placements/%s", | ||
projectId, placementId); | ||
|
||
predict(predictPlacement); | ||
|
||
String outputResult = bout.toString(); | ||
|
||
assertThat(outputResult).contains("Predict request"); | ||
assertThat(outputResult).contains("filter: \"filterOutOfStockItems\""); | ||
assertThat(outputResult).contains("Predict response"); | ||
} | ||
|
||
@After | ||
public void tearDown() { | ||
System.out.flush(); | ||
System.setOut(originalPrintStream); | ||
} | ||
} |
Oops, something went wrong.