Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.microsoft.semantickernel.exceptions.SKException;
import com.microsoft.semantickernel.orchestration.responseformat.ResponseSchemaGenerator;
import com.microsoft.semantickernel.semanticfunctions.InputVariable;
import com.microsoft.semantickernel.semanticfunctions.KernelFunctionMetadata;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -159,14 +164,17 @@ private static String getSchemaForFunctionParameter(@Nullable InputVariable para
entries.add("\"type\":\"" + type + "\"");

// Add description if present
String description =null;
if (parameter != null && parameter.getDescription() != null && !parameter.getDescription()
.isEmpty()) {
String description = parameter.getDescription();
description = parameter.getDescription();
description = description.replaceAll("\\r?\\n|\\r", "");
description = description.replace("\"", "\\\"");

description = String.format("\"description\":\"%s\"", description);
entries.add(description);
entries.add(String.format("\"description\":\"%s\"", description));
}
// If custom type, generate schema
if("object".equalsIgnoreCase(type)) {
return getObjectSchema(parameter.getType(), description);
}

// Add enum options if parameter is an enum
Expand Down Expand Up @@ -219,4 +227,20 @@ private static String getJavaTypeToOpenAiFunctionType(String javaType) {
return "object";
}
}

private static String getObjectSchema(String type, String description){
String schema= "{ \"type\" : \"object\" }";
try {
Class<?> clazz = Class.forName(type);
schema = ResponseSchemaGenerator.jacksonGenerator().generateSchema(clazz);

} catch (ClassNotFoundException | SKException ignored) {

}
Map<String, Object> properties = BinaryData.fromString(schema).toObject(Map.class);
if(StringUtils.isNotBlank(description)) {
properties.put("description", description);
}
return BinaryData.fromObject(properties).toString();
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.aiservices.openai.chatcompletion;

import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.microsoft.semantickernel.orchestration.responseformat.JsonSchemaResponseFormat;
import com.microsoft.semantickernel.plugin.KernelPlugin;
import com.microsoft.semantickernel.plugin.KernelPluginFactory;
import com.microsoft.semantickernel.semanticfunctions.KernelFunction;
import com.microsoft.semantickernel.semanticfunctions.annotations.DefineKernelFunction;
import com.microsoft.semantickernel.semanticfunctions.annotations.KernelFunctionParameter;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;

public class JsonSchemaTest {

Expand All @@ -24,4 +31,86 @@ public void jacksonGenerationTest() throws JsonProcessingException {
"\"type\":\"object\",\"properties\":{\"bar\":{}}"));
}

@Test
public void openAIFunctionTest() {
KernelPlugin plugin = KernelPluginFactory.createFromObject(
new TestPlugin(),
"test");

Assertions.assertNotNull(plugin);
Assertions.assertEquals(plugin.getName(), "test");
Assertions.assertEquals(plugin.getFunctions().size(), 3);

KernelFunction<?> testFunction = plugin.getFunctions()
.get("asyncPersonFunction");
OpenAIFunction openAIFunction = OpenAIFunction.build(
testFunction.getMetadata(),
plugin.getName());

String parameters = "{\"type\":\"object\",\"required\":[\"person\",\"input\"],\"properties\":{\"input\":{\"type\":\"string\",\"description\":\"input string\"},\"person\":{\"type\":\"object\",\"properties\":{\"age\":{\"type\":\"integer\",\"description\":\"The age of the person.\"},\"name\":{\"type\":\"string\",\"description\":\"The name of the person.\"},\"title\":{\"type\":\"string\",\"enum\":[\"MS\",\"MRS\",\"MR\"],\"description\":\"The title of the person.\"}},\"required\":[\"age\",\"name\",\"title\"],\"additionalProperties\":false,\"description\":\"input person\"}}}";
Assertions.assertEquals(parameters, openAIFunction.getFunctionDefinition().getParameters().toString());

}


public static class TestPlugin {

@DefineKernelFunction
public String testFunction(
@KernelFunctionParameter(name = "input", description = "input string") String input) {
return "test" + input;
}

@DefineKernelFunction(returnType = "int")
public Mono<Integer> asyncTestFunction(
@KernelFunctionParameter(name = "input") String input) {
return Mono.just(1);
}

@DefineKernelFunction(returnType = "int", description = "test function description",
name = "asyncPersonFunction", returnDescription = "test return description")
public Mono<Integer> asyncPersonFunction(
@KernelFunctionParameter(name = "person",description = "input person", type = Person.class) Person person,
@KernelFunctionParameter(name = "input", description = "input string") String input) {
return Mono.just(1);
}
}

private static enum Title {
MS,
MRS,
MR
}

public static class Person {
@JsonPropertyDescription("The name of the person.")
private String name;
@JsonPropertyDescription("The age of the person.")
private int age;
@JsonPropertyDescription("The title of the person.")
private Title title;


public Person(String name, int age) {
this.name = name;
this.age = age;
}

public String getName() {
return name;
}

public int getAge() {
return age;
}

public Title getTitle() {
return title;
}

public void setTitle(Title title) {
this.title = title;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ public static void main(String[] args) throws Exception {
ChatCompletionService.class);

ContextVariableTypes
.addGlobalConverter(ContextVariableTypeConverter.builder(LightModel.class)
.toPromptString(new Gson()::toJson)
.build());
.addGlobalConverter(new LightModelTypeConverter());

KernelHooks hook = new KernelHooks();

Expand All @@ -99,9 +97,7 @@ public static void main(String[] args) throws Exception {
InvocationContext invocationContext = new Builder()
.withReturnMode(InvocationReturnMode.LAST_MESSAGE_ONLY)
.withToolCallBehavior(ToolCallBehavior.allowAllKernelFunctions(true))
.withContextVariableConverter(ContextVariableTypeConverter.builder(LightModel.class)
.toPromptString(new Gson()::toJson)
.build())
.withContextVariableConverter(new LightModelTypeConverter())
.build();

// Create a history to store the conversation
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.samples.demos.lights;

import com.fasterxml.jackson.annotation.JsonPropertyDescription;

public class LightModel {

@JsonPropertyDescription("The unique identifier of the light")
private int id;

@JsonPropertyDescription("The name of the light")
private String name;

@JsonPropertyDescription("The state of the light")
private Boolean isOn;

public LightModel(int id, String name, Boolean isOn) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.microsoft.semantickernel.samples.demos.lights;

import com.google.gson.Gson;
import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter;

public class LightModelTypeConverter extends ContextVariableTypeConverter<LightModel> {
private static final Gson gson = new Gson();

public LightModelTypeConverter() {
super(
LightModel.class,
obj -> {
if(obj instanceof String) {
return gson.fromJson((String)obj, LightModel.class);
} else {
return gson.fromJson(gson.toJson(obj), LightModel.class);
}
},
(types, lightModel) -> gson.toJson(lightModel),
json -> gson.fromJson(json, LightModel.class)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ public List<LightModel> getLights() {
return lights;
}

@DefineKernelFunction(name = "add_light", description = "Adds a new light")
public String addLight(
@KernelFunctionParameter(name = "newLight", description = "new Light Details", type = LightModel.class) LightModel light) {
if( light != null) {
System.out.println("Adding light " + light.getName());
lights.add(light);
return "Light added";
}
return "Light failed to added";
}

@DefineKernelFunction(name = "change_state", description = "Changes the state of the light")
public LightModel changeState(
@KernelFunctionParameter(name = "id", description = "The ID of the light to change", type = int.class) int id,
Expand Down