Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions core-services/prompt-registry/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@
</scm>
<properties>
<project.rootdir>${project.basedir}/../../</project.rootdir>
<coverage.complexity>84%</coverage.complexity>
<coverage.line>90%</coverage.line>
<coverage.instruction>92%</coverage.instruction>
<coverage.complexity>92%</coverage.complexity>
<coverage.line>94%</coverage.line>
<coverage.instruction>95%</coverage.instruction>
<coverage.branch>100%</coverage.branch>
<coverage.method>80%</coverage.method>
<coverage.method>85%</coverage.method>
<coverage.class>100%</coverage.class>
</properties>

Expand Down Expand Up @@ -85,6 +85,10 @@
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,17 @@

import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.google.common.collect.Iterables;
import com.sap.ai.sdk.core.AiCoreService;
import com.sap.ai.sdk.prompt.registry.client.PromptTemplatesApi;
import com.sap.ai.sdk.prompt.registry.model.MultiChatContent;
import com.sap.ai.sdk.prompt.registry.model.MultiChatTemplate;
import com.sap.ai.sdk.prompt.registry.model.PromptTemplate;
import com.sap.ai.sdk.prompt.registry.model.PromptTemplateSpecResponseFormat;
import com.sap.ai.sdk.prompt.registry.model.ResponseFormatJsonObject;
Expand All @@ -16,6 +23,8 @@
import com.sap.ai.sdk.prompt.registry.model.SingleChatTemplate;
import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor;
import com.sap.cloud.sdk.services.openapi.apiclient.ApiClient;
import java.io.IOException;
import java.util.ArrayList;
import javax.annotation.Nonnull;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
Expand Down Expand Up @@ -75,7 +84,7 @@ private static ApiClient addMixin(@Nonnull final AiCoreService service) {
@NoArgsConstructor(access = AccessLevel.PRIVATE)
private static class JacksonMixin {
@JsonTypeInfo(use = JsonTypeInfo.Id.NONE)
@JsonDeserialize(as = SingleChatTemplate.class)
@JsonDeserialize(using = PromptTemplateDeserializer.class)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Major)

I assume this is a breaking change. Or it is a fix. Anyway it requires a dedicated PR with semantic title: fix: [PromptRegistry] ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is a fix to address the 2nd bullet point of the expected state :
image

So, shall I create a separate PR now with the 2nd point alone?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively just rename this PR, and keep everything as is.

LGTM

interface TemplateMixIn {}

@JsonTypeInfo(
Expand All @@ -90,4 +99,40 @@ interface TemplateMixIn {}
})
interface ResponseFormat {}
}

private static class PromptTemplateDeserializer extends JsonDeserializer<PromptTemplate> {

@Override
public PromptTemplate deserialize(
@Nonnull final JsonParser jsonParser,
@Nonnull final DeserializationContext deserializationContext)
throws IOException {

final JsonNode root = jsonParser.readValueAsTree();
final JsonNode roleNode = root.path("role");
final String role = roleNode.asText();
final JsonNode content = root.path("content");

if (!roleNode.isTextual()) {
throw JsonMappingException.from(
jsonParser, "PromptTemplate requires textual 'role' property.");
}

if (content.isTextual()) {
return SingleChatTemplate.create().role(role).content(content.asText());
}
if (content.isArray()) {
final var contentList = new ArrayList<MultiChatContent>();
for (final JsonNode item : content) {
contentList.add(jsonParser.getCodec().treeToValue(item, MultiChatContent.class));
}
return MultiChatTemplate.create().role(role).content(contentList);
}

throw JsonMappingException.from(
jsonParser,
"PromptTemplate content must be either a string or an array, but found: "
+ content.getNodeType());
}
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package com.sap.ai.sdk.prompt.registry.spring;

import com.sap.ai.sdk.prompt.registry.model.ImageContent;
import com.sap.ai.sdk.prompt.registry.model.MultiChatTemplate;
import com.sap.ai.sdk.prompt.registry.model.PromptTemplate;
import com.sap.ai.sdk.prompt.registry.model.PromptTemplateSubstitutionResponse;
import com.sap.ai.sdk.prompt.registry.model.SingleChatTemplate;
import com.sap.ai.sdk.prompt.registry.model.TextContent;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import lombok.val;
import org.springframework.ai.chat.messages.AssistantMessage;
Expand Down Expand Up @@ -34,17 +38,44 @@ public static List<Message> promptTemplateToMessages(
// TRANSFORM TEMPLATE TO SPRING AI MESSAGES
return res.stream()
.map(
(PromptTemplate t) -> {
final SingleChatTemplate message = (SingleChatTemplate) t;
return (Message)
switch (message.getRole()) {
case "system" -> new SystemMessage(message.getContent());
case "user" -> new UserMessage(message.getContent());
case "assistant" -> new AssistantMessage(message.getContent());
default ->
throw new IllegalArgumentException("Unknown role: " + message.getRole());
};
(PromptTemplate template) -> {
if (template instanceof SingleChatTemplate message) {
return fromRole(message.getRole(), message.getContent());
}
if (template instanceof MultiChatTemplate message) {
return fromRole(message.getRole(), getMultiTemplateTextContent(message));
}
throw new IllegalArgumentException(
"Unsupported PromptTemplate type: " + template.getClass().getName());
})
.toList();
}

@Nonnull
private static String getMultiTemplateTextContent(@Nonnull final MultiChatTemplate message) {
return message.getContent().stream()
.map(
item -> {
if (item instanceof TextContent textContent) {
return textContent.getText();
}
if (item instanceof ImageContent) {
throw new UnsupportedOperationException(
"MultiChatTemplate with image content is not supported by SpringAiConverter yet.");
}
throw new UnsupportedOperationException(
"Unsupported MultiChatContent type: " + item.getClass().getName());
})
.collect(Collectors.joining("\n"));
}

@Nonnull
private static Message fromRole(@Nonnull final String role, @Nonnull final String content) {
return switch (role) {
case "system" -> new SystemMessage(content);
case "user" -> new UserMessage(content);
case "assistant" -> new AssistantMessage(content);
default -> throw new IllegalArgumentException("Unknown role: " + role);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@

import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
import com.sap.ai.sdk.core.AiCoreService;
import com.sap.ai.sdk.prompt.registry.model.MultiChatTemplate;
import com.sap.ai.sdk.prompt.registry.model.PromptTemplateGetResponse;
import com.sap.ai.sdk.prompt.registry.model.PromptTemplateSubstitutionRequest;
import com.sap.ai.sdk.prompt.registry.model.ResponseFormatJsonObject;
import com.sap.ai.sdk.prompt.registry.model.ResponseFormatJsonSchema;
import com.sap.ai.sdk.prompt.registry.model.ResponseFormatText;
import com.sap.ai.sdk.prompt.registry.model.SingleChatTemplate;
import com.sap.ai.sdk.prompt.registry.model.TextContent;
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination;
import java.util.Map;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -95,4 +101,66 @@ void testGetTemplateWithResponseFormatJsonSchema() {
assertThat(format.getJsonSchema().isStrict()).isFalse();
assertThat(format.getJsonSchema().getSchema()).isNotNull();
}

@Test
void testGetTemplateWithMultiChatTemplate() {
final var uuid = UUID.fromString("8f79fec4-ae07-4c35-96e3-df7f4a3f1df5");
final var response = client.getPromptTemplateByUuid(uuid);

assertThat(response.getSpec()).isNotNull();
assertThat(response.getSpec().getTemplate()).hasSize(2);
assertThat(response.getSpec().getTemplate().get(0)).isInstanceOf(SingleChatTemplate.class);
assertThat(response.getSpec().getTemplate().get(1)).isInstanceOf(MultiChatTemplate.class);

final var multiTemplate = (MultiChatTemplate) response.getSpec().getTemplate().get(1);
assertThat(multiTemplate.getRole()).isEqualTo("user");
assertThat(multiTemplate.getContent()).hasSize(2);
assertThat(multiTemplate.getContent().get(0)).isInstanceOf(TextContent.class);
assertThat(((TextContent) multiTemplate.getContent().get(0)).getText())
.isEqualTo("First content line");
assertThat(((TextContent) multiTemplate.getContent().get(1)).getText())
.isEqualTo("Second content line");
}

@Test
void testGetTemplateWithInvalidRoleType() {
final var uuid = UUID.fromString("45cb1358-0bf1-4f43-870b-00f14d0f9f16");

assertThatThrownBy(() -> client.getPromptTemplateByUuid(uuid))
.hasStackTraceContaining("PromptTemplate requires textual 'role' property.");
}

@Test
void testGetTemplateWithInvalidContentType() {
final var uuid = UUID.fromString("55cb1358-0bf1-4f43-870b-00f14d0f9f16");

assertThatThrownBy(() -> client.getPromptTemplateByUuid(uuid))
.hasStackTraceContaining(
"PromptTemplate content must be either a string or an array, but found: BOOLEAN");
}

@Test
void testParsePromptTemplateHotPath() {
final var request =
PromptTemplateSubstitutionRequest.create()
.inputParams(Map.of("inputExample", "I love football"));

final var response =
client.parsePromptTemplateByNameVersion(
"categorization", "0.0.1", "hotpath-serde", "default", null, false, request);

assertThat(response.getParsedPrompt()).hasSize(2);
assertThat(response.getParsedPrompt().get(0)).isInstanceOf(SingleChatTemplate.class);
assertThat(response.getParsedPrompt().get(1)).isInstanceOf(SingleChatTemplate.class);

final var systemTemplate = (SingleChatTemplate) response.getParsedPrompt().get(0);
assertThat(systemTemplate.getRole()).isEqualTo("system");
assertThat(systemTemplate.getContent())
.isEqualTo(
"You classify input text into the two following categories: Finance, Tech, Sports");

final var userTemplate = (SingleChatTemplate) response.getParsedPrompt().get(1);
assertThat(userTemplate.getRole()).isEqualTo("user");
assertThat(userTemplate.getContent()).isEqualTo("I love football");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
import com.sap.ai.sdk.core.AiCoreService;
import com.sap.ai.sdk.prompt.registry.PromptClient;
import com.sap.ai.sdk.prompt.registry.model.MultiChatContent;
import com.sap.ai.sdk.prompt.registry.model.MultiChatTemplate;
import com.sap.ai.sdk.prompt.registry.model.PromptTemplate;
import com.sap.ai.sdk.prompt.registry.model.PromptTemplateSubstitutionRequest;
import com.sap.ai.sdk.prompt.registry.model.PromptTemplateSubstitutionResponse;
import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination;
import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination;
import java.util.List;
Expand Down Expand Up @@ -68,4 +72,65 @@ void testInvalidRoleThrowsException() {
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Unknown role: error");
}

@Test
void testMultiChatTemplateTextContentToSpringAi() {
var client = new PromptClient(SERVICE);
val promptResponse =
client.parsePromptTemplateByNameVersion(
"categorization",
"0.0.1",
"multi-text",
"default",
null,
false,
PromptTemplateSubstitutionRequest.create()
.inputParams(Map.of("inputExample", "I love football")));

List<Message> messages = SpringAiConverter.promptTemplateToMessages(promptResponse);
assertThat(messages)
.isEqualTo(List.of(new UserMessage("First content line\nSecond content line")));
}

@Test
void testMultiChatTemplateImageContentThrowsException() {
var client = new PromptClient(SERVICE);
val promptResponse =
client.parsePromptTemplateByNameVersion(
"categorization",
"0.0.1",
"multi-image",
"default",
null,
false,
PromptTemplateSubstitutionRequest.create()
.inputParams(Map.of("inputExample", "I love football")));

assertThatThrownBy(() -> SpringAiConverter.promptTemplateToMessages(promptResponse))
.isInstanceOf(UnsupportedOperationException.class)
.hasMessageContaining("image content is not supported");
}

@Test
void testUnsupportedPromptTemplateTypeThrowsException() {
final PromptTemplate unsupportedTemplate = new PromptTemplate() {};
final var promptResponse =
PromptTemplateSubstitutionResponse.create().parsedPrompt(List.of(unsupportedTemplate));

assertThatThrownBy(() -> SpringAiConverter.promptTemplateToMessages(promptResponse))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Unsupported PromptTemplate type");
}

@Test
void testUnsupportedMultiChatContentTypeThrowsException() {
final MultiChatContent unsupportedContent = new MultiChatContent() {};
final var message = MultiChatTemplate.create().role("user").content(unsupportedContent);
final var promptResponse =
PromptTemplateSubstitutionResponse.create().parsedPrompt(List.of(message));

assertThatThrownBy(() -> SpringAiConverter.promptTemplateToMessages(promptResponse))
.isInstanceOf(UnsupportedOperationException.class)
.hasMessageContaining("Unsupported MultiChatContent type");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"request": {
"method": "GET",
"url": "/v2/lm/promptTemplates/55cb1358-0bf1-4f43-870b-00f14d0f9f16"
},
"response": {
"status": 200,
"headers": {
"Content-Type": "application/json"
},
"jsonBody": {
"id": "55cb1358-0bf1-4f43-870b-00f14d0f9f16",
"name": "invalid-content",
"version": "0.0.1",
"scenario": "test-retrival",
"creationTimestamp": "2025-12-02T16:06:05.400000",
"managedBy": "imperative",
"isVersionHead": true,
"spec": {
"template": [
{
"role": "user",
"content": true
}
],
"tools": []
}
}
}
}
Loading
Loading