Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ml): predict reading level during epub import #1822

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,12 @@
<version>model-${model.version}</version>
</dependency>

<!-- PMML scoring library (org.pmml4s.model.Model); used to evaluate the bundled reading-level model during EPUB import -->
<dependency>
<groupId>org.pmml4s</groupId>
<artifactId>pmml4s_3</artifactId>
<version>1.0.1</version>
</dependency>

<dependency>
<groupId>commons-fileupload</groupId>
<artifactId>commons-fileupload</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.elimu.model.contributor.Contributor;
import ai.elimu.model.contributor.ImageContributionEvent;
import ai.elimu.model.contributor.StoryBookContributionEvent;
import ai.elimu.model.v2.enums.ReadingLevel;
import ai.elimu.model.v2.enums.content.ImageFormat;
import ai.elimu.util.DiscordHelper;
import ai.elimu.util.ImageColorHelper;
Expand All @@ -39,6 +40,7 @@
import java.util.Arrays;
import java.util.Calendar;
import java.util.List;
import java.util.Map;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import org.apache.commons.io.FileUtils;
Expand Down Expand Up @@ -404,6 +406,22 @@ public String handleSubmit(
storyBookParagraphDao.create(storyBookParagraph);
}
}

List<StoryBookChapter> chapters = storyBookChapterDao.readAll(storyBook);
int chapterCount = chapters.size();
int paragraphCount = 0;
int wordCount = 0;
for (StoryBookChapter chapter : chapters) {
List<StoryBookParagraph> paragraphs = storyBookParagraphDao.readAll(chapter);
paragraphCount += paragraphs.size();
for (StoryBookParagraph paragraph : paragraphs) {
wordCount += paragraph.getOriginalText().split(" ").length;
}
}
ReadingLevel predictedReadingLevel = predictReadingLevel(chapterCount, paragraphCount, wordCount);
logger.info("predictedReadingLevel: " + predictedReadingLevel);
storyBook.setReadingLevel(predictedReadingLevel);
storyBookDao.update(storyBook);

if (!EnvironmentContextLoaderListener.PROPERTIES.isEmpty()) {
String contentUrl = "https://" + EnvironmentContextLoaderListener.PROPERTIES.getProperty("content.language").toLowerCase() + ".elimu.ai/content/storybook/edit/" + storyBook.getId();
Expand Down Expand Up @@ -518,4 +536,46 @@ private void storeImageContributionEvent(Image image, HttpSession session, HttpS
);
}
}

/**
 * Predicts a storybook's reading level from its size, using the PMML model bundled next to this
 * class (trained in https://github.com/elimu-ai/ml-storybook-reading-level).
 *
 * @param chapterCount the number of chapters in the storybook
 * @param paragraphCount the total number of paragraphs across all chapters
 * @param wordCount the total number of words across all paragraphs
 * @return the {@link ReadingLevel} whose name is {@code "LEVEL" + <predicted number>}
 * @throws IllegalStateException if the model resource is missing, the model asks for a feature
 *         this method does not supply, or the model does not return a numeric prediction
 * @throws java.io.UncheckedIOException if the model resource cannot be read
 * @throws IllegalArgumentException if the predicted number has no matching {@link ReadingLevel}
 */
private ReadingLevel predictReadingLevel(int chapterCount, int paragraphCount, int wordCount) {
    logger.info("predictReadingLevel");

    // Load the model as a stream rather than via URL#getFile(): a file path is URL-encoded and
    // does not exist at all when the resource is packaged inside a JAR/WAR.
    final String modelResourceName = "step2_2_model.pmml";
    final org.pmml4s.model.Model model;
    try (java.io.InputStream modelInputStream = getClass().getResourceAsStream(modelResourceName)) {
        if (modelInputStream == null) {
            throw new IllegalStateException("Model resource not found on the classpath: " + modelResourceName);
        }
        model = org.pmml4s.model.Model.fromInputStream(modelInputStream);
    } catch (java.io.IOException e) {
        throw new java.io.UncheckedIOException("Failed to read model resource: " + modelResourceName, e);
    }
    logger.info("model: " + model);

    // Prepare values (features) to pass to the model
    Map<String, Double> values = Map.of(
            "chapter_count", Double.valueOf(chapterCount),
            "paragraph_count", Double.valueOf(paragraphCount),
            "word_count", Double.valueOf(wordCount)
    );
    logger.info("values: " + values);

    // Arrange the features in the order the model declares its inputs
    String[] inputNames = model.inputNames();
    logger.info("Arrays.toString(model.inputNames()): " + Arrays.toString(inputNames));
    Object[] inputValues = new Object[inputNames.length];
    for (int i = 0; i < inputNames.length; i++) {
        Double value = values.get(inputNames[i]);
        if (value == null) {
            // Passing null would make the model silently return a null prediction
            throw new IllegalStateException("No value available for model input: " + inputNames[i]);
        }
        inputValues[i] = value;
    }
    logger.info("inputValues: " + Arrays.toString(inputValues));

    // Make prediction
    Object[] results = model.predict(inputValues);
    logger.info("Arrays.toString(results): " + Arrays.toString(results));
    if (results == null || results.length == 0 || !(results[0] instanceof Number)) {
        throw new IllegalStateException("Model returned no numeric prediction: " + Arrays.toString(results));
    }
    // Truncates toward zero, e.g. 2.0 -> 2
    int predictedLevelNumber = ((Number) results[0]).intValue();
    logger.info("predictedLevelNumber: " + predictedLevelNumber);

    // Convert from number to ReadingLevel enum (e.g. "LEVEL2")
    String readingLevelAsString = "LEVEL" + predictedLevelNumber;
    ReadingLevel readingLevel = ReadingLevel.valueOf(readingLevelAsString);
    logger.info("readingLevel: " + readingLevel);
    return readingLevel;
}
Comment on lines +540 to +580
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enhance error handling and logging.

The predictReadingLevel method effectively handles the prediction process. However, consider adding error handling for potential issues like file not found or invalid model predictions. Additionally, improve logging by including more context about the prediction process.

try {
    org.pmml4s.model.Model model = org.pmml4s.model.Model.fromFile(modelFilePath);
    // Existing logic
} catch (FileNotFoundException e) {
    logger.error("Model file not found: " + modelFilePath, e);
    throw new RuntimeException("Model file not found", e);
} catch (Exception e) {
    logger.error("Error during model prediction", e);
    throw new RuntimeException("Error during model prediction", e);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<!--
  Storybook reading-level model, exported by the SkLearn2PMML package (see Header).
  A regression decision tree (sklearn DecisionTreeRegressor) that maps three numeric inputs
  (word_count, chapter_count, paragraph_count) to the target reading_level.
  Leaf scores in this tree range from 1.0 to 4.0.
  Generated artefact: retrain and re-export rather than editing the nodes by hand.
-->
<PMML xmlns="http://www.dmg.org/PMML-4_4" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.4">
<Header>
<Application name="SkLearn2PMML package" version="0.110.0"/>
<Timestamp>2024-08-17T11:40:01Z</Timestamp>
</Header>
<DataDictionary>
<DataField name="reading_level" optype="continuous" dataType="double"/>
<DataField name="chapter_count" optype="continuous" dataType="float"/>
<DataField name="paragraph_count" optype="continuous" dataType="float"/>
<DataField name="word_count" optype="continuous" dataType="float"/>
</DataDictionary>
<TreeModel functionName="regression" algorithmName="sklearn.tree._classes.DecisionTreeRegressor" missingValueStrategy="nullPrediction" noTrueChildStrategy="returnLastPrediction">
<MiningSchema>
<MiningField name="reading_level" usageType="target"/>
<MiningField name="word_count"/>
<MiningField name="chapter_count"/>
<MiningField name="paragraph_count"/>
</MiningSchema>
<LocalTransformations>
<DerivedField name="double(word_count)" optype="continuous" dataType="double">
<FieldRef field="word_count"/>
</DerivedField>
<DerivedField name="double(chapter_count)" optype="continuous" dataType="double">
<FieldRef field="chapter_count"/>
</DerivedField>
<DerivedField name="double(paragraph_count)" optype="continuous" dataType="double">
<FieldRef field="paragraph_count"/>
</DerivedField>
</LocalTransformations>
<Node score="4.0">
<True/>
<Node score="2.0">
<SimplePredicate field="double(word_count)" operator="lessOrEqual" value="319.0"/>
<Node score="1.0">
<SimplePredicate field="double(chapter_count)" operator="lessOrEqual" value="15.5"/>
<Node score="1.0">
<SimplePredicate field="double(word_count)" operator="lessOrEqual" value="210.0"/>
<Node score="1.0">
<SimplePredicate field="double(word_count)" operator="lessOrEqual" value="127.0"/>
</Node>
<Node score="2.0">
<SimplePredicate field="double(word_count)" operator="lessOrEqual" value="135.5"/>
</Node>
</Node>
<Node score="2.0">
<SimplePredicate field="double(word_count)" operator="lessOrEqual" value="234.5"/>
</Node>
</Node>
<Node score="1.0">
<SimplePredicate field="double(paragraph_count)" operator="lessOrEqual" value="20.0"/>
</Node>
</Node>
<Node score="2.0">
<SimplePredicate field="double(word_count)" operator="lessOrEqual" value="521.5"/>
<Node score="3.0">
<SimplePredicate field="double(word_count)" operator="lessOrEqual" value="383.5"/>
</Node>
</Node>
<Node score="2.0">
<SimplePredicate field="double(paragraph_count)" operator="lessOrEqual" value="51.5"/>
<Node score="3.0">
<SimplePredicate field="double(paragraph_count)" operator="lessOrEqual" value="47.5"/>
<Node score="4.0">
<SimplePredicate field="double(word_count)" operator="lessOrEqual" value="559.5"/>
</Node>
</Node>
</Node>
</Node>
</TreeModel>
</PMML>
Loading