Skip to content

Commit

Permalink
[OPIK-311] Update prompt (#555)
Browse files Browse the repository at this point in the history
* [OPIK-309] Create prompt endpoint

* [OPIK-309] Expose API contracts

* [OPIK-309] Expose API contracts

* [OPIK-310] Expose get prompts api

* Add logic to create first version when specified

* Initial commit

* Address PR review

* [OPIK-314] Create prompt version endpoint

* Fix error

* [OPIK-311] Update prompt

* Add missing test

* Fix Test name
  • Loading branch information
thiagohora authored Nov 5, 2024
1 parent a186a36 commit a3e1ab4
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import com.fasterxml.jackson.annotation.JsonView;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.Valid;
import lombok.Builder;

@Builder(toBuilder = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
public record Prompt(
@JsonView( {
Prompt.View.Public.class, Prompt.View.Write.class, Prompt.View.Detail.class}) UUID id,
@JsonView({Prompt.View.Public.class, Prompt.View.Write.class, Prompt.View.Detail.class}) @NotBlank String name,
@JsonView({Prompt.View.Public.class, Prompt.View.Write.class, Prompt.View.Detail.class,
Prompt.View.Updatable.class}) @NotBlank String name,
@JsonView({Prompt.View.Public.class,
Prompt.View.Write.class,
Prompt.View.Detail.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description,
Prompt.View.Detail.class,
Prompt.View.Updatable.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description,
@JsonView({
Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Nullable String template,
@JsonView({Prompt.View.Public.class,
Expand All @@ -50,6 +52,9 @@ public static class Public {

public static class Detail {
}

public static class Updatable {
}
}

@Builder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import com.comet.opik.api.PromptVersion;
import com.comet.opik.api.PromptVersionRetrieve;
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.api.Prompt;
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.domain.IdGenerator;
import com.comet.opik.domain.PromptService;
import com.comet.opik.infrastructure.auth.RequestContext;
Expand All @@ -34,20 +32,6 @@
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.inject.Inject;
import jakarta.inject.Provider;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Min;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.DELETE;
import jakarta.ws.rs.DefaultValue;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
Expand Down Expand Up @@ -166,15 +150,14 @@ public Response getPromptById(@PathParam("id") UUID id) {
@RateLimited
public Response updatePrompt(
@PathParam("id") UUID id,
@RequestBody(content = @Content(schema = @Schema(implementation = Prompt.class))) @JsonView(Prompt.View.Write.class) @Valid Prompt prompt) {

@RequestBody(content = @Content(schema = @Schema(implementation = Prompt.class))) @JsonView(Prompt.View.Updatable.class) @Valid Prompt prompt) {
String workspaceId = requestContext.get().getWorkspaceId();

log.info("Updating prompt with id '{}' on workspace_id '{}'", id, workspaceId);

promptService.update(id, prompt);
log.info("Updated prompt with id '{}' on workspace_id '{}'", id, workspaceId);

return Response.status(Response.Status.NOT_IMPLEMENTED).build();
return Response.noContent().build();
}

@DELETE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,34 @@
interface PromptDAO {

@SqlUpdate("INSERT INTO prompts (id, name, description, created_by, last_updated_by, workspace_id) " +
"VALUES (:bean.id, :bean.name, :bean.description, :bean.createdBy, :bean.lastUpdatedBy, :workspaceId)")
void save(@Bind("workspaceId") String workspaceId, @BindMethods("bean") Prompt prompt);
"VALUES (:bean.id, :bean.name, :bean.description, :bean.createdBy, :bean.lastUpdatedBy, :workspace_id)")
void save(@Bind("workspace_id") String workspaceId, @BindMethods("bean") Prompt prompt);

@SqlQuery("SELECT * FROM prompts WHERE id = :id AND workspace_id = :workspaceId")
Prompt findById(@Bind("id") UUID id, @Bind("workspaceId") String workspaceId);
@SqlQuery("SELECT * FROM prompts WHERE id = :id AND workspace_id = :workspace_id")
Prompt findById(@Bind("id") UUID id, @Bind("workspace_id") String workspaceId);

@SqlQuery("SELECT * FROM prompts " +
" WHERE workspace_id = :workspace_Id " +
" WHERE workspace_id = :workspace_id " +
" <if(name)> AND name like concat('%', :name, '%') <endif> " +
" ORDER BY id DESC " +
" LIMIT :limit OFFSET :offset ")
@UseStringTemplateEngine
@AllowUnusedBindings
List<Prompt> find(@Define("name") @Bind("name") String name, @Bind("workspace_Id") String workspaceId,
List<Prompt> find(@Define("name") @Bind("name") String name, @Bind("workspace_id") String workspaceId,
@Bind("offset") int offset, @Bind("limit") int limit);

@SqlQuery("SELECT COUNT(id) FROM prompts " +
" WHERE workspace_id = :workspace_Id " +
" WHERE workspace_id = :workspace_id " +
" <if(name)> AND name like concat('%', :name, '%') <endif> ")
@UseStringTemplateEngine
@AllowUnusedBindings
long count(@Define("name") @Bind("name") String name, @Bind("workspace_Id") String workspaceId);
long count(@Define("name") @Bind("name") String name, @Bind("workspace_id") String workspaceId);

@SqlQuery("SELECT * FROM prompts WHERE name = :name AND workspace_id = :workspace_id")
Prompt findByName(@Bind("name") String name, @Bind("workspace_id") String workspaceId);

@SqlUpdate("UPDATE prompts SET name = :bean.name, description = :bean.description, last_updated_by = :bean.lastUpdatedBy "
+
" WHERE id = :bean.id AND workspace_id = :workspace_id")
int update(@Bind("workspace_id") String workspaceId, @BindMethods("bean") Prompt updatedPrompt);
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import jakarta.inject.Inject;
import jakarta.inject.Provider;
import jakarta.inject.Singleton;
import jakarta.ws.rs.NotFoundException;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -31,6 +32,8 @@ public interface PromptService {
PromptPage find(String name, int page, int size);

PromptVersion createPromptVersion(CreatePromptVersion promptVersion);

void update(@NonNull UUID id, Prompt prompt);
}

@Singleton
Expand All @@ -40,12 +43,13 @@ class PromptServiceImpl implements PromptService {

private static final String ALREADY_EXISTS = "Prompt id or name already exists";
private static final String VERSION_ALREADY_EXISTS = "Prompt version already exists";

private final @NonNull Provider<RequestContext> requestContext;
private final @NonNull IdGenerator idGenerator;
private final @NonNull TransactionTemplate transactionTemplate;

@Override
public Prompt create(Prompt prompt) {
public Prompt create(@NonNull Prompt prompt) {

String workspaceId = requestContext.get().getWorkspaceId();
String userName = requestContext.get().getUserName();
Expand Down Expand Up @@ -201,6 +205,36 @@ public PromptVersion createPromptVersion(@NonNull CreatePromptVersion createProm
}
}

@Override
public void update(@NonNull UUID id, @NonNull Prompt prompt) {
String workspaceId = requestContext.get().getWorkspaceId();
String userName = requestContext.get().getUserName();

EntityConstraintHandler
.handle(() -> updatePrompt(id, prompt, userName, workspaceId))
.withError(this::newPromptConflict);
}

private Prompt updatePrompt(UUID id, Prompt prompt, String userName, String workspaceId) {
Prompt updatedPrompt = prompt.toBuilder()
.lastUpdatedBy(userName)
.id(id)
.build();

return transactionTemplate.inTransaction(WRITE, handle -> {
PromptDAO promptDAO = handle.attach(PromptDAO.class);

if (promptDAO.update(workspaceId, updatedPrompt) > 0) {
log.info("Updated prompt with id '{}'", id);
} else {
log.info("Prompt with id '{}' not found", id);
throw new NotFoundException("Prompt not found");
}

return updatedPrompt;
});
}

private PromptVersion retryableCreateVersion(String workspaceId, CreatePromptVersion request, Prompt prompt,
String userName) {
return EntityConstraintHandler.handle(() -> {
Expand Down
Loading

0 comments on commit a3e1ab4

Please sign in to comment.