Skip to content

Commit

Permalink
Merge pull request #1280 from edeandrea/rewrite-tools
Browse files Browse the repository at this point in the history
Rewrite auditing to use CDI events
  • Loading branch information
geoand authored Feb 13, 2025
2 parents 4959fb7 + 6bdaca0 commit 000f9e3
Show file tree
Hide file tree
Showing 21 changed files with 375 additions and 398 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -314,13 +314,6 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}

DotName auditServiceSupplierClassName = LangChain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER;
AnnotationValue auditServiceSupplierValue = instance.value("auditServiceSupplier");
if (auditServiceSupplierValue != null) {
auditServiceSupplierClassName = auditServiceSupplierValue.asClass().name();
validateSupplierAndRegisterForReflection(auditServiceSupplierClassName, index, reflectiveClassProducer);
}

DotName moderationModelSupplierClassName = LangChain4jDotNames.BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER;
AnnotationValue moderationModelSupplierValue = instance.value("moderationModelSupplier");
if (moderationModelSupplierValue != null) {
Expand Down Expand Up @@ -397,7 +390,6 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
chatMemoryProviderSupplierClassDotName,
retrievalAugmentorSupplierClassName,
customRetrievalAugmentorSupplierClassIsABean,
auditServiceSupplierClassName,
moderationModelSupplierClassName,
imageModelSupplierClassName,
determineChatMemorySeeder(declarativeAiServiceClassInfo, generatedClassOutput),
Expand Down Expand Up @@ -500,7 +492,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
boolean needsChatMemoryProviderBean = false;
boolean needsRetrieverBean = false;
boolean needsRetrievalAugmentorBean = false;
boolean needsAuditServiceBean = false;
boolean needsModerationModelBean = false;
boolean needsImageModelBean = false;
boolean needsToolProviderBean = false;
Expand Down Expand Up @@ -545,10 +536,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
? bi.getRetrievalAugmentorSupplierClassDotName().toString()
: null;

String auditServiceClassSupplierName = bi.getAuditServiceClassSupplierDotName() != null
? bi.getAuditServiceClassSupplierDotName().toString()
: null;

String moderationModelSupplierClassName = (bi.getModerationModelSupplierDotName() != null
? bi.getModerationModelSupplierDotName().toString()
: null);
Expand Down Expand Up @@ -621,7 +608,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
toolProviderSupplierClassName,
chatMemoryProviderSupplierClassName,
retrievalAugmentorSupplierClassName,
auditServiceClassSupplierName,
moderationModelSupplierClassName,
imageModelSupplierClassName,
chatMemorySeederClassName,
Expand Down Expand Up @@ -698,12 +684,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
}
}

if (LangChain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER.toString().equals(auditServiceClassSupplierName)) {
configurator.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
new Type[] { ClassType.create(LangChain4jDotNames.AUDIT_SERVICE) }, null));
needsAuditServiceBean = true;
}

if (LangChain4jDotNames.BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER.toString()
.equals(moderationModelSupplierClassName) && injectModerationModelBean) {

Expand Down Expand Up @@ -765,9 +745,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
if (needsRetrievalAugmentorBean) {
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.RETRIEVAL_AUGMENTOR));
}
if (needsAuditServiceBean) {
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.AUDIT_SERVICE));
}
if (needsModerationModelBean) {
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(LangChain4jDotNames.MODERATION_MODEL));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final DotName chatMemoryProviderSupplierClassDotName;
private final DotName retrievalAugmentorSupplierClassDotName;
private final boolean customRetrievalAugmentorSupplierClassIsABean;
private final DotName auditServiceClassSupplierDotName;
private final DotName moderationModelSupplierDotName;
private final DotName imageModelSupplierDotName;
private final DotName chatMemorySeederClassDotName;
Expand All @@ -40,7 +39,6 @@ public DeclarativeAiServiceBuildItem(
DotName chatMemoryProviderSupplierClassDotName,
DotName retrievalAugmentorSupplierClassDotName,
boolean customRetrievalAugmentorSupplierClassIsABean,
DotName auditServiceClassSupplierDotName,
DotName moderationModelSupplierDotName,
DotName imageModelSupplierDotName,
DotName chatMemorySeederClassDotName,
Expand All @@ -57,7 +55,6 @@ public DeclarativeAiServiceBuildItem(
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
this.retrievalAugmentorSupplierClassDotName = retrievalAugmentorSupplierClassDotName;
this.customRetrievalAugmentorSupplierClassIsABean = customRetrievalAugmentorSupplierClassIsABean;
this.auditServiceClassSupplierDotName = auditServiceClassSupplierDotName;
this.moderationModelSupplierDotName = moderationModelSupplierDotName;
this.imageModelSupplierDotName = imageModelSupplierDotName;
this.chatMemorySeederClassDotName = chatMemorySeederClassDotName;
Expand Down Expand Up @@ -97,10 +94,6 @@ public boolean isCustomRetrievalAugmentorSupplierClassIsABean() {
return customRetrievalAugmentorSupplierClassIsABean;
}

public DotName getAuditServiceClassSupplierDotName() {
return auditServiceClassSupplierDotName;
}

public DotName getModerationModelSupplierDotName() {
return moderationModelSupplierDotName;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.SeedMemory;
import io.quarkiverse.langchain4j.audit.AuditService;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContextQualifier;
Expand Down Expand Up @@ -91,11 +90,6 @@ public class LangChain4jDotNames {
static final DotName NO_RETRIEVAL_AUGMENTOR_SUPPLIER = DotName.createSimple(
RegisterAiService.NoRetrievalAugmentorSupplier.class);

static final DotName AUDIT_SERVICE = DotName.createSimple(AuditService.class);

static final DotName BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsAuditServiceSupplier.class);

static final DotName BEAN_IF_EXISTS_MODERATION_MODEL_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsModerationModelSupplier.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.spi.services.AiServicesFactory;
import io.quarkiverse.langchain4j.audit.AuditService;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
Expand Down Expand Up @@ -46,11 +45,6 @@ public AiServices<T> tools(Collection<Object> objectsWithTools) {
return this;
}

public AiServices<T> auditService(AuditService auditService) {
quarkusAiServiceContext().auditService = auditService;
return this;
}

public AiServices<T> chatMemorySeeder(ChatMemorySeeder chatMemorySeeder) {
quarkusAiServiceContext().chatMemorySeeder = chatMemorySeeder;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
import io.quarkiverse.langchain4j.audit.AuditService;

/**
* Used to create LangChain4j's {@link AiServices} in a declarative manner that the application can then use simply by
Expand Down Expand Up @@ -114,15 +113,6 @@
*/
Class<? extends Supplier<RetrievalAugmentor>> retrievalAugmentor() default BeanIfExistsRetrievalAugmentorSupplier.class;

/**
* Configures the way to obtain the {@link AuditService} to use.
* By default, Quarkus will look for a CDI bean that implements {@link AuditService}, but will fall back to not using
* any memory if no such bean exists.
* If an arbitrary {@link AuditService} instance is needed, a custom implementation of
* {@link Supplier<AuditService>} needs to be provided.
*/
Class<? extends Supplier<AuditService>> auditServiceSupplier() default BeanIfExistsAuditServiceSupplier.class;

/**
* Configures the way to obtain the {@link ModerationModel} to use.
* By default, Quarkus will look for a CDI bean that implements {@link ModerationModel} if at least one method is annotated
Expand Down Expand Up @@ -235,18 +225,6 @@ public RetrievalAugmentor get() {
}
}

/**
* Marker that is used to tell Quarkus to use the {@link AuditService} that the user has configured as a CDI bean.
* If no such bean exists, then no audit service will be used.
*/
final class BeanIfExistsAuditServiceSupplier implements Supplier<AuditService> {

@Override
public AuditService get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker that is used to tell Quarkus to use the {@link ModerationModel} that the user has configured as a CDI bean.
* If no such bean exists, then no audit service will be used.
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package io.quarkiverse.langchain4j.audit;

import java.util.Optional;
import java.util.UUID;

/**
* Contains information about the source of an audit event
*/
public interface AuditSourceInfo {
/**
* The fully-qualified name of the interface where the llm interaction was initialized
*
* @see #methodName()
*/
String interfaceName();

/**
* The method name on {@link #interfaceName()} where the llm interaction was initiated
*
* @see #interfaceName()
*/
String methodName();

/**
* The position of the memory id parameter in {@link #methodParams()}, if one exists
*/
Optional<Integer> memoryIDParamPosition();

/**
* The parameters passed into the initial LLM call
*/
Object[] methodParams();

/**
* A unique identifier that identifies this entire interaction with the LLM
*/
UUID interactionId();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package io.quarkiverse.langchain4j.audit;

import java.util.Optional;

import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;

/**
* Invoked when the original user and system messages have been created
*/
public record InitialMessagesCreatedEvent(AuditSourceInfo sourceInfo, Optional<SystemMessage> systemMessage,
UserMessage userMessage) implements LLMInteractionEvent {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package io.quarkiverse.langchain4j.audit;

/**
* Invoked when the final result of the AiService method has been computed
*/
public record LLMInteractionCompleteEvent(AuditSourceInfo sourceInfo, Object result) implements LLMInteractionEvent {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package io.quarkiverse.langchain4j.audit;

public interface LLMInteractionEvent {
AuditSourceInfo sourceInfo();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package io.quarkiverse.langchain4j.audit;

/**
* Invoked when there was an exception computing the result of the AiService method
*/
public record LLMInteractionFailureEvent(AuditSourceInfo sourceInfo, Exception error) implements LLMInteractionEvent {
}
Loading

0 comments on commit 000f9e3

Please sign in to comment.