diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateChat.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateChat.kt new file mode 100644 index 00000000000..0ed8f42f86f --- /dev/null +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateChat.kt @@ -0,0 +1,246 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.ai + +import com.google.firebase.ai.common.JSON +import com.google.firebase.ai.type.AutoFunctionDeclaration +import com.google.firebase.ai.type.Content +import com.google.firebase.ai.type.FirebaseAutoFunctionException +import com.google.firebase.ai.type.FunctionCallPart +import com.google.firebase.ai.type.FunctionResponsePart +import com.google.firebase.ai.type.GenerateContentResponse +import com.google.firebase.ai.type.InvalidStateException +import com.google.firebase.ai.type.PublicPreviewAPI +import com.google.firebase.ai.type.RequestTimeoutException +import com.google.firebase.ai.type.TemplateAutoFunctionDeclaration +import com.google.firebase.ai.type.TemplateTool +import com.google.firebase.ai.type.TemplateToolConfig +import com.google.firebase.ai.type.content +import java.util.LinkedList +import java.util.concurrent.Semaphore +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.FlowCollector +import kotlinx.coroutines.flow.onCompletion +import kotlinx.coroutines.flow.transform +import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.jsonObject + +/** + * Representation of a multi-turn interaction with a server template model. + */ +@PublicPreviewAPI +public class TemplateChat( + private val model: TemplateGenerativeModel, + private val templateId: String, + private val inputs: Map, + public val history: MutableList = ArrayList(), + private val tools: List? = null, + private val toolConfig: TemplateToolConfig? = null, +) { + private var lock = Semaphore(1) + private var turns: Int = 0 + + /** + * Sends a message using the provided [prompt]; automatically providing the existing [history] as + * context. + * + * @param prompt The input that, together with the history, will be given to the model as the + * prompt. + */ + public suspend fun sendMessage(prompt: Content): GenerateContentResponse { + prompt.assertComesFromUser() + attemptLock() + var response: GenerateContentResponse + try { + val tempHistory = mutableListOf(prompt) + while (true) { + response = + model.generateContentWithHistory( + templateId, + inputs, + listOf(*history.toTypedArray(), *tempHistory.toTypedArray()), + tools, + toolConfig + ) + tempHistory.add(response.candidates.first().content) + val functionCallParts = + response.candidates.first().content.parts.filterIsInstance() + + if (functionCallParts.isEmpty()) { + break + } + if (model.requestOptions.autoFunctionCallingTurnLimit < ++turns) { + throw RequestTimeoutException("Request took too many turns", history = tempHistory) + } + if (!functionCallParts.all { hasFunction(it) }) { + break + } + val functionResponsePart = functionCallParts.map { executeFunction(it) } + tempHistory.add(Content("function", functionResponsePart)) + } + history.addAll(tempHistory) + return response + } finally { + lock.release() + turns = 0 + } + } + + /** + * Sends a message using the provided text [prompt]; automatically providing the existing [history] + * as context. + */ + public suspend fun sendMessage(prompt: String): GenerateContentResponse { + val content = content { text(prompt) } + return sendMessage(content) + } + + /** + * Sends a message using the provided [prompt]; automatically providing the existing [history] as + * context. Returns a flow. + */ + public fun sendMessageStream(prompt: Content): Flow { + prompt.assertComesFromUser() + attemptLock() + + val fullPrompt = history + prompt + val flow = model.generateContentWithHistoryStream(templateId, inputs, fullPrompt, tools, toolConfig) + val tempHistory = LinkedList() + tempHistory.add(prompt) + + return flow + .transform { response -> automaticFunctionExecutingTransform(this, tempHistory, response) } + .onCompletion { + turns = 0 + lock.release() + if (it == null) { + history.addAll(tempHistory) + } + } + } + + /** + * Sends a message using the provided text [prompt]; automatically providing the existing [history] + * as context. Returns a flow. + */ + public fun sendMessageStream(prompt: String): Flow { + val content = content { text(prompt) } + return sendMessageStream(content) + } + + private suspend fun automaticFunctionExecutingTransform( + transformer: FlowCollector, + tempHistory: MutableList, + response: GenerateContentResponse + ) { + val functionCallParts = + response.candidates.first().content.parts.filterIsInstance() + if (functionCallParts.isNotEmpty()) { + if (functionCallParts.all { hasFunction(it) }) { + if (model.requestOptions.autoFunctionCallingTurnLimit < ++turns) { + throw RequestTimeoutException("Request took too many turns", history = tempHistory) + } + val functionResponses = + Content("function", functionCallParts.map { executeFunction(it) }) + tempHistory.add(Content("model", functionCallParts)) + tempHistory.add(functionResponses) + model + .generateContentWithHistoryStream( + templateId, + inputs, + listOf(*history.toTypedArray(), *tempHistory.toTypedArray()), + tools, + toolConfig + ) + .collect { automaticFunctionExecutingTransform(transformer, tempHistory, it) } + } else { + transformer.emit(response) + tempHistory.add(Content("model", functionCallParts)) + } + } else { + transformer.emit(response) + tempHistory.add(response.candidates.first().content) + } + } + + internal fun hasFunction(call: FunctionCallPart): Boolean { + if (tools == null) return false + return tools + .flatMap { it.templateAutoFunctionDeclarations } + .firstOrNull { it.name == call.name && it.functionReference != null } != null + } + + @OptIn(InternalSerializationApi::class) + internal suspend fun executeFunction(call: FunctionCallPart): FunctionResponsePart { + if (tools.isNullOrEmpty()) { + throw RuntimeException("No registered tools") + } + val tool = tools.flatMap { it.templateAutoFunctionDeclarations } + val declaration = + tool.firstOrNull() { it.name == call.name } + ?: throw RuntimeException("No registered function named ${call.name}") + return executeFunction( + call, + declaration as TemplateAutoFunctionDeclaration, + JsonObject(call.args).toString() + ) + } + + @OptIn(InternalSerializationApi::class) + internal suspend fun executeFunction( + functionCall: FunctionCallPart, + functionDeclaration: TemplateAutoFunctionDeclaration, + parameter: String + ): FunctionResponsePart { + val inputDeserializer = functionDeclaration.inputSchema.getSerializer() + val input = JSON.decodeFromString(inputDeserializer, parameter) + val functionReference = + functionDeclaration.functionReference + ?: throw RuntimeException("Function reference for ${functionDeclaration.name} is missing") + try { + val output = functionReference.invoke(input) + val outputSerializer = functionDeclaration.outputSchema?.getSerializer() + if (outputSerializer != null) { + return FunctionResponsePart.from( + JSON.encodeToJsonElement(outputSerializer, output).jsonObject + ) + .normalizeAgainstCall(functionCall) + } + return (output as FunctionResponsePart).normalizeAgainstCall(functionCall) + } catch (e: FirebaseAutoFunctionException) { + return FunctionResponsePart.from(JsonObject(mapOf("error" to JsonPrimitive(e.message)))) + .normalizeAgainstCall(functionCall) + } + } + + private fun Content.assertComesFromUser() { + if (role !in listOf("user", "function")) { + throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.") + } + } + + private fun attemptLock() { + if (!lock.tryAcquire()) { + throw InvalidStateException( + "This chat instance currently has an ongoing request, please wait for it to complete " + + "before sending more messages" + ) + } + } +} diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateGenerativeModel.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateGenerativeModel.kt index 8672813fff3..5d66400973c 100644 --- a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateGenerativeModel.kt +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateGenerativeModel.kt @@ -29,6 +29,8 @@ import com.google.firebase.ai.type.PublicPreviewAPI import com.google.firebase.ai.type.RequestOptions import com.google.firebase.ai.type.ResponseStoppedException import com.google.firebase.ai.type.SerializationException +import com.google.firebase.ai.type.TemplateTool +import com.google.firebase.ai.type.TemplateToolConfig import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider import com.google.firebase.auth.internal.InternalAuthProvider import kotlinx.coroutines.flow.Flow @@ -79,6 +81,8 @@ internal constructor( * * @param templateId The ID of the prompt template to use. * @param inputs A map of variables to substitute into the template. + * @param tools A list of [TemplateTool]s the model may use to generate content. + * @param toolConfig The [TemplateToolConfig] that defines how the model handles the tools provided. * @return The content generated by the model. * @throws [FirebaseAIException] if the request failed. * @see [FirebaseAIException] for types of errors. @@ -86,10 +90,33 @@ internal constructor( public suspend fun generateContent( templateId: String, inputs: Map, + tools: List? = null, + toolConfig: TemplateToolConfig? = null, + ): GenerateContentResponse = generateContentWithHistory(templateId, inputs, null, tools, toolConfig) + + /** + * Generates content from a prompt template and inputs. + * + * @param templateId The ID of the prompt template to use. + * @param inputs A map of variables to substitute into the template. + * @param history Prior history in the conversation. + * @param tools A list of [TemplateTool]s the model may use to generate content. + * @param toolConfig The [TemplateToolConfig] that defines how the model handles the tools provided. + * @return The content generated by the model. + * @throws [FirebaseAIException] if the request failed. + * @see [FirebaseAIException] for types of errors. + */ + @PublicPreviewAPI + public suspend fun generateContentWithHistory( + templateId: String, + inputs: Map, + history: List? = null, + tools: List? = null, + toolConfig: TemplateToolConfig? = null, ): GenerateContentResponse = try { controller - .templateGenerateContent("$templateUri$templateId", constructRequest(inputs)) + .templateGenerateContent("$templateUri$templateId", constructRequest(inputs, history, tools, toolConfig)) .toPublic() .validate() } catch (e: Throwable) { @@ -101,29 +128,80 @@ internal constructor( * * @param templateId The ID of the prompt template to use. * @param inputs A map of variables to substitute into the template. + * @param tools A list of [TemplateTool]s the model may use to generate content. + * @param toolConfig The [TemplateToolConfig] that defines how the model handles the tools provided. * @return A [Flow] which will emit responses as they are returned by the model. * @throws [FirebaseAIException] if the request failed. * @see [FirebaseAIException] for types of errors. */ public fun generateContentStream( templateId: String, - inputs: Map + inputs: Map, + tools: List? = null, + toolConfig: TemplateToolConfig? = null, + ): Flow = generateContentWithHistoryStream(templateId, inputs, null, tools, toolConfig) + + /** + * Generates content as a stream from a prompt template, inputs, and history. + * + * @param templateId The ID of the prompt template to use. + * @param inputs A map of variables to substitute into the template. + * @param history Prior history in the conversation. + * @param tools A list of [TemplateTool]s the model may use to generate content. + * @param toolConfig The [TemplateToolConfig] that defines how the model handles the tools provided. + * @return A [Flow] which will emit responses as they are returned by the model. + * @throws [FirebaseAIException] if the request failed. + * @see [FirebaseAIException] for types of errors. + */ + @PublicPreviewAPI + public fun generateContentWithHistoryStream( + templateId: String, + inputs: Map, + history: List? = null, + tools: List? = null, + toolConfig: TemplateToolConfig? = null, ): Flow = controller - .templateGenerateContentStream("$templateUri$templateId", constructRequest(inputs)) + .templateGenerateContentStream("$templateUri$templateId", constructRequest(inputs, history, tools, toolConfig)) .catch { throw FirebaseAIException.from(it) } .map { it.toPublic().validate() } + /** + * Creates a [TemplateChat] instance using this model with the optionally provided history. + * + * @param templateId The ID of the prompt template to use. + * @param inputs A map of variables to substitute into the template for the session. + * @param history Prior history in the conversation. + * @param tools A list of [TemplateTool]s the model may use to generate content. + * @param toolConfig The [TemplateToolConfig] that defines how the model handles the tools provided. + * @return The initialized [TemplateChat] instance. + */ + @PublicPreviewAPI + public fun startChat( + templateId: String, + inputs: Map, + history: List = emptyList(), + tools: List? = null, + toolConfig: TemplateToolConfig? = null, + ): TemplateChat = TemplateChat(this, templateId, inputs, history.toMutableList(), tools, toolConfig) + internal fun constructRequest( inputs: Map, - history: List? = null + history: List? = null, + tools: List? = null, + toolConfig: TemplateToolConfig? = null, ): TemplateGenerateContentRequest { return TemplateGenerateContentRequest( Json.parseToJsonElement(JSONObject(inputs).toString()).jsonObject, - history?.let { it.map { it.toTemplateInternal() } } + history?.let { it.map { c -> c.toTemplateInternal() } }, + tools?.map { it.toInternal() }, + toolConfig?.toInternal() ) } + internal val requestOptions: RequestOptions + get() = controller.requestOptions + private fun GenerateContentResponse.validate() = apply { if (candidates.isEmpty() && promptFeedback == null) { throw SerializationException("Error deserializing response, found no valid fields") diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt index e3a4afcb56c..6a12a23d18d 100644 --- a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt @@ -49,7 +49,9 @@ internal data class GenerateContentRequest( @Serializable internal data class TemplateGenerateContentRequest( val inputs: JsonObject, - val history: List? + val history: List?, + val tools: List? = null, + @SerialName("tool_config") val toolConfig: ToolConfig.Internal? = null, ) : Request @Serializable internal data class TemplateGenerateImageRequest(val inputs: JsonObject) : Request diff --git a/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/TemplateTool.kt b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/TemplateTool.kt new file mode 100644 index 00000000000..07d649da22c --- /dev/null +++ b/ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/TemplateTool.kt @@ -0,0 +1,86 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.ai.type + +import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject + +/** + * Contains a set of tools (like function declarations) that the server template model has access to. + */ +public class TemplateTool +@OptIn(PublicPreviewAPI::class) +internal constructor( + internal val functionDeclarations: List?, +) { + + public val templateAutoFunctionDeclarations: List> + get() = functionDeclarations?.filterIsInstance>() ?: emptyList() + + @OptIn(PublicPreviewAPI::class) + internal fun toInternal() = + Tool.Internal( + functionDeclarations = functionDeclarations?.map { it.toInternal() } + ) + + public companion object { + + /** + * Creates a [TemplateTool] instance that provides the model with access to the [functionDeclarations]. + * + * @param functionDeclarations The list of functions that this tool allows the model access to. + */ + @JvmStatic + public fun functionDeclarations( + functionDeclarations: List, + ): TemplateTool { + return TemplateTool(functionDeclarations) + } + } +} + +/** + * A function declaration for a template tool. + */ +public open class TemplateFunctionDeclaration( + public val name: String, + public val parameters: Schema? = null +) { + internal fun toInternal(): FunctionDeclaration.Internal { + return FunctionDeclaration.Internal(name, parameters?.toInternal()) + } +} + +/** + * A function declaration for a template tool that can be called by the model automatically. + */ +public class TemplateAutoFunctionDeclaration( + name: String, + public val inputSchema: Schema, + public val outputSchema: Schema? = null, + public val functionReference: (suspend (I) -> O), +) : TemplateFunctionDeclaration(name, inputSchema) + +/** + * Config for template tools to use with server prompts. + */ +public class TemplateToolConfig { + internal fun toInternal(): ToolConfig.Internal? { + return null // Empty config payload as defined in flutter API + } +}