Привет!
Постараюсь изложить здесь свой опыт с реализацией RAG и функциональных вызовов для LLM модели, запущенной локально на телефоне.
Для cloud моделей общая идея не будет отличаться, так что почитав, вы сможете понять, как это работает и для чего используется.
Интересно здесь то, что успешность работы сильно зависит от правильности написания промпта того, что мы хотим от модели, то есть просто хороший код здесь не работает, нужна еще гуманитарная часть- объяснить все модели и надеяться на то, что она правильно все поняла.
Что такое RAG
В LLM модель при обучении помещаются терабайты информации, но информацию о вас, вашем проекте, ваших документах и вашем коте она может не знать.
Вы можете описать все это в запросе к модели, в промпте, но если информации много- размер контекста модели может закончится до того, как вы дойдете до вопроса, либо вам надоест набирать текст.
Чтобы модель знала информацию, на которой она не была обучена, используется RAG (Retrieval Augmented Generation).
У RAG может быть много разных реализаций, но общий смысл таков:
К вашему сообщению добавляется дополнительная информация, обычно в текстовом виде, например:
- Результат интернет запроса
- Данные из базы данных
- Текст из прикрепленного документа, и так далее
В этом примере я буду использовать векторную базу, которая отлично подходит для работы с LLM моделями.
Что такое векторная база данных
Перед запросами к модели нужно скормить базе данных ваши данные, чтобы потом модель могла найти наиболее подходящие из них для вашего запроса.
При добавлении в базу текста:
- Текст разбивается на куски
- При помощи специальной embedding модели Gecko для каждого куска генерируется вектор с 768 измерениями, смысл которого максимально соответствует содержимому куска текста
- Вектор вместе с текстом записывается в базу
При поиске текста в базе:
- Поисковый запрос при помощи модели Gecko преобразуется в вектор
- В базе ищется вектор, максимально близкий к вектору поискового запроса
- Для этого вектора вытаскивается текст
- Текст добавляется к вашему запросу к LLM модели
Что такое Function Calling
Разница с RAG в том, что в этом случае модель сама решает, обращаться ли ей к векторной базе или другому источнику для получения информации.
Для этого:
- в промпте указывается, что у модели есть доступ к функциональным вызовам, описывается их формат, описываются условия, при которых модели нужно делать этот вызов
- если модель решает сделать такой вызов- она возвращает служебную информацию с запросом на вызов
- парсер понимает, что нужно сделать вызов, читает данные из векторной базы, добавляет их в очередь сообщений и отправляет запрос снова.
В этом примере я буду использовать
Движок: Google AI Edge MediaPipe
LLM модель: gemma-3n-E4B-it-int4
RAG / Functional Calling: On-Device RAG SDK & On-Device Function Calling SDK
Embedding модель: Gecko-110m-en
Как это должно работать- можно почитать здесь
RAG:
https://blogs.nvidia.com/blog/what-is-retrieval-augmented-generation
https://en.wikipedia.org/wiki/Retrieval-augmented_generation
https://www.ibm.com/think/topics/retrieval-augmented-generation
Function Calling:
https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling
https://huggingface.co/docs/hugs/guides/function-calling
https://platform.openai.com/docs/guides/function-calling
https://medium.com/@danushidk507/function-calling-in-llm-e537b286a4fd
В этом примере будет 3 класса
MediaPipeEngineCommon: здесь будут храниться общие компоненты для работы с векторной базой
MediaPipeEngineWithRag: здесь можно запустить генерацию, добавив к запросу данные из векторной базы
MediaPipeEngineWithTools: здесь можно запустить генерацию и модель сама решает, делать ли запрос к векторной базе
Зависимости в проекте, в новом Gradle формате:
localagentsRag = "0.2.0" localagentsFc = "0.1.0" tasksGenai = "0.10.25" tasksText = "0.10.26.1" tasksVision = "0.10.26.1" tensorflowLite = "2.17.0" kotlinxCoroutinesGuava = "1.10.2" tasks-genai = { module = "com.google.mediapipe:tasks-genai", version.ref = "tasksGenai" } tasks-text = { module = "com.google.mediapipe:tasks-text", version.ref = "tasksText" } tasks-vision = { module = "com.google.mediapipe:tasks-vision", version.ref = "tasksVision" } tensorflow-lite = { module = "org.tensorflow:tensorflow-lite", version.ref = "tensorflowLite" } localagents-rag = { module = "com.google.ai.edge.localagents:localagents-rag", version.ref = "localagentsRag" } localagents-fc = { module = "com.google.ai.edge.localagents:localagents-fc", version.ref = "localagentsFc" } kotlinx-coroutines-guava = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-guava", version.ref = "kotlinxCoroutinesGuava" }
MediaPipeEngineCommon:
import com.google.ai.edge.localagents.rag.chunking.TextChunker import com.google.ai.edge.localagents.rag.memory.DefaultSemanticTextMemory import com.google.ai.edge.localagents.rag.memory.SqliteVectorStore import com.google.ai.edge.localagents.rag.models.Embedder import com.google.ai.edge.localagents.rag.prompt.PromptBuilder interface MediaPipeEngineCommon { var chunker: TextChunker var embedder: Embedder<String> var vectorStore: SqliteVectorStore var promptBuilder: PromptBuilder var semanticMemory: DefaultSemanticTextMemory fun init( geckoModelPath: String, // Gecko_256_quant.tflite tokenizerModelPath: String, // sentencepiece.model useGpuForEmbeddings: Boolean = true, ) fun saveTextToVectorStore( text: String, chunkOverlap: Int = 20, chunkTokenSize: Int = 128, chunkMaxSymbolsSize: Int = 1000, chunkBySentences: Boolean = false, ): String? fun readEmbeddingVectors(): List<VectorStoreEntity> suspend fun readEmbeddingVectors( query: String, topK: Int, minSimilarityScore: Float, ): List<VectorStoreEntity> fun makeSQLRequest(query: String): Boolean }
import android.app.Application import android.database.Cursor import android.database.sqlite.SQLiteDatabase import com.google.ai.edge.localagents.rag.chunking.TextChunker import com.google.ai.edge.localagents.rag.memory.DefaultSemanticTextMemory import com.google.ai.edge.localagents.rag.memory.SqliteVectorStore import com.google.ai.edge.localagents.rag.memory.VectorStoreRecord import com.google.ai.edge.localagents.rag.models.EmbedData import com.google.ai.edge.localagents.rag.models.Embedder import com.google.ai.edge.localagents.rag.models.EmbeddingRequest import com.google.ai.edge.localagents.rag.models.GeckoEmbeddingModel import com.google.ai.edge.localagents.rag.prompt.PromptBuilder import com.google.common.collect.ImmutableList import com.romankryvolapov.offlineailauncher.common.extensions.toDurationString import com.romankryvolapov.offlineailauncher.common.models.common.LogUtil.logDebug import com.romankryvolapov.offlineailauncher.common.models.common.LogUtil.logError import kotlinx.coroutines.guava.await import java.io.File import java.nio.ByteBuffer import java.nio.ByteOrder import java.util.Optional class MediaPipeEngineCommonImpl( private val application: Application ) : MediaPipeEngineCommon { companion object { private const val TAG = "CommonComponentsTag" private const val GECKO_EMBEDDING_MODEL_DIMENSION = 768 private const val PROMPT_TEMPLATE: String = "You are an assistant for question-answering tasks. Here are the things I want to remember: {0} Use the things I want to remember, answer the following question the user has: {1}" } override lateinit var chunker: TextChunker override lateinit var embedder: Embedder<String> override lateinit var vectorStore: SqliteVectorStore override lateinit var promptBuilder: PromptBuilder override lateinit var semanticMemory: DefaultSemanticTextMemory override fun init( geckoModelPath: String, tokenizerModelPath: String, useGpuForEmbeddings: Boolean, ) { logDebug("init", TAG) chunker = TextChunker() // в embedder добавляем путь до модели Gecko-110m-en // я использую версию Gecko_256_quant.tflite здесь 256 это максимальный размер текстового входа // эта версия оптимальна с точки зрения размера кусков текста и быстродействия // важно- далее в коде мы передаем, на какие куски разбивать текст, это зависит от параметров модели embedder = GeckoEmbeddingModel( geckoModelPath, Optional.of(tokenizerModelPath), useGpuForEmbeddings, ) // здесь я буду использовать уже готовую SQLite базу в приложении // в нее просто добавится еще одна таблица rag_vector_store // с колонками text для текста и embeddings для вектора val database = File(application.getDatabasePath("database").absolutePath) if (!database.exists()) { logError("startEngine database not exists", TAG) } // по идее можно сд елать и кастомную реализацию, наследующую // интерфейс VectorStore<String>, но метод getNearestRecords должен // быть реализован правильно и работать быстро, он ищет ближайшие вестора vectorStore = SqliteVectorStore( GECKO_EMBEDDING_MODEL_DIMENSION, database.absolutePath ) semanticMemory = DefaultSemanticTextMemory( vectorStore, embedder ) promptBuilder = PromptBuilder( PROMPT_TEMPLATE ) logDebug("init ready", TAG) } override fun saveTextToVectorStore( text: String, // на сколько залезать в текст до фрагмента, здесь 20 chunkOverlap: Int, // Обратите внимание, что размер вроде как в токенах, // но он используется для chunker и может не соответствовать // размеру токенов для embedder, здесь chunkTokenSize 128 chunkTokenSize: Int, // при разбивании при помощи chunkBySentences // размер предложений может быть большим // если он привысит возможности embedder модели, она выдаст ошибку // для этого используется обрезка до максимального размера // здесь он 1000 символов chunkMaxSymbolsSize: Int, // использовать метод разбивки по предложениям chunkBySentences: Boolean, ): String? { logDebug("saveTextToVectorStore text length: ${text.length}", TAG) // таймер, чтобы понять, насколько быстро работает val start = System.currentTimeMillis() val chunks: List<String> = if (chunkBySentences) chunker.chunkBySentences( text, chunkTokenSize, ).filter { it.isNotBlank() }.map { chunk -> if (chunk.length > chunkMaxSymbolsSize) { logError("saveTextToVectorStore crop chunk", TAG) chunk.substring(0, chunkMaxSymbolsSize) } else { chunk } } else chunker.chunk( text, chunkTokenSize, chunkOverlap ).filter { it.isNotBlank() }.map { chunk -> if (chunk.length > chunkMaxSymbolsSize) { logError("saveTextToVectorStore crop chunk", TAG) chunk.substring(0, chunkMaxSymbolsSize) } else { chunk } } val end = System.currentTimeMillis() val delta = end - start logDebug("saveTextToVectorStore chunks delta: ${delta.toDurationString()} size: ${chunks.size}", TAG) chunks.forEach { logDebug("length: ${it.length}", TAG) } if (chunks.isEmpty()) { logError("saveTextToVectorStore chunks.isEmpty()", TAG) return "Chunks is empty" } return try { // генерация вектора происходит внутри semanticMemory val result: Boolean? = semanticMemory.recordBatchedMemoryItems( ImmutableList.copyOf(chunks) )?.get() val end = System.currentTimeMillis() val delta = end - start logDebug("saveTextToVectorStore ready delta: ${delta.toDurationString()} result: $result", TAG) null } catch (t: Throwable) { logError("saveTextToVectorStore failed: ${t.message}", t, TAG) t.message } } // поиска по запросу query, найдет все похожие на запрос куски текста override suspend fun readEmbeddingVectors( query: String, // количество результатов запроса к базе topK: Int, // насколько вектор запроса query должен быть похожим на запись в базе // 0.0 = искать все записи, отсортировать по самым похожим // 1.0 = только идеальное совпадение // я использую значения 0.6 - 0.8 minSimilarityScore: Float, ): List<VectorStoreEntity> { logDebug("readEmbeddingVectors query: $query", TAG) val queryEmbedData: EmbedData<String> = EmbedData.create( query, EmbedData.TaskType.RETRIEVAL_QUERY ) val embeddingRequest: EmbeddingRequest<String> = EmbeddingRequest .create( listOf(queryEmbedData) ) val vector: ImmutableList<Float> = try { embedder.getEmbeddings(embeddingRequest).await() } catch (t: Throwable) { logError("readEmbeddingVectors: embedding failed: ${t.message}", t, TAG) return emptyList() } logDebug("searchDocsInternal vector size: ${vector.size}", TAG) if (vector.isEmpty()) { logError("readEmbeddingVectors vector.isEmpty()", TAG) return emptyList() } val hits: ImmutableList<VectorStoreRecord<String>> = try { vectorStore.getNearestRecords( vector, topK, minSimilarityScore ) } catch (t: Throwable) { logError("readEmbeddingVectors: vector search failed: ${t.message}", t, TAG) return emptyList() } if (hits.isEmpty()) { logError("readEmbeddingVectors hits.isEmpty()", TAG) return emptyList() } val result = hits.map { VectorStoreEntity( id = null, text = it.data, embedding = it.embeddings ) } logDebug("readEmbeddingVectors\nsize: ${result.size}\nresult: $result", TAG) return result } // просто выводит все записи в базе override fun readEmbeddingVectors(): List<VectorStoreEntity> { logDebug("readEmbeddingPreview", TAG) var cursor: Cursor? = null var database: SQLiteDatabase? = null return try { val databaseFile = File(application.getDatabasePath("database").absolutePath) database = SQLiteDatabase.openDatabase( databaseFile.absolutePath, null, SQLiteDatabase.OPEN_READONLY ) cursor = database.rawQuery("SELECT ROWID, text, embeddings FROM rag_vector_store", null) val result = mutableListOf<VectorStoreEntity>() while (cursor.moveToNext()) { val rowId = cursor.getLong(0) val text = cursor.getString(1) val blob = cursor.getBlob(2) val buffer = ByteBuffer.wrap(blob).order(ByteOrder.LITTLE_ENDIAN) val floats = mutableListOf<Float>() while (buffer.hasRemaining()) { floats.add(buffer.float) } result.add( VectorStoreEntity( id = rowId, text = text, embedding = floats ) ) } logDebug("readEmbeddingPreview\nsize: ${result.size}\nresult: $result", TAG) result } catch (t: Throwable) { logError("readEmbeddingPreview failed: ${t.message}", t, TAG) emptyList() } finally { cursor?.close() database?.close() } } // можно написать свой запрос и он выполнится, // например "DELETE FROM rag_vector_store" override fun makeSQLRequest(query: String): Boolean { logDebug("makeSQLRequest query: $query", TAG) var cursor: Cursor? = null var database: SQLiteDatabase? = null return try { val databaseFile = File(application.getDatabasePath("database").absolutePath) database = SQLiteDatabase.openDatabase( databaseFile.absolutePath, null, SQLiteDatabase.OPEN_READWRITE ) cursor = database.rawQuery(query, null) val result = cursor.moveToFirst() logDebug("makeSQLRequest result: $result", TAG) result } catch (t: Throwable) { logError("makeSQLRequest failed: ${t.message}", t, TAG) false } finally { cursor?.close() database?.close() } } }
MediaPipeEngineWithRag:
import kotlinx.coroutines.flow.Flow import java.io.File interface MediaPipeEngineWithRag { fun startEngine( modelFile: File, isSupportImages: Boolean = false, engineParams: MediaPipeEngineParams, ) fun resetSession() fun generateResponse( prompt: String, topK: Int = 5, minSimilarityScore: Float = 0.6F, ): Flow<ResultEmittedData<String>> }
import android.app.Application import com.google.ai.edge.localagents.rag.chains.ChainConfig import com.google.ai.edge.localagents.rag.chains.RetrievalAndInferenceChain import com.google.ai.edge.localagents.rag.models.AsyncProgressListener import com.google.ai.edge.localagents.rag.models.LanguageModelResponse import com.google.ai.edge.localagents.rag.models.MediaPipeLlmBackend import com.google.ai.edge.localagents.rag.retrieval.RetrievalConfig import com.google.ai.edge.localagents.rag.retrieval.RetrievalConfig.TaskType import com.google.ai.edge.localagents.rag.retrieval.RetrievalRequest import com.google.common.util.concurrent.FutureCallback import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.MoreExecutors import com.google.mediapipe.tasks.genai.llminference.GraphOptions import com.google.mediapipe.tasks.genai.llminference.LlmInference import com.google.mediapipe.tasks.genai.llminference.LlmInference.LlmInferenceOptions import com.google.mediapipe.tasks.genai.llminference.LlmInferenceSession.LlmInferenceSessionOptions import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.callbackFlow import java.io.File import java.util.concurrent.Executor class MediaPipeEngineWithRagImpl( private val application: Application, private val common: MediaPipeEngineCommon, ) : MediaPipeEngineWithRag { companion object { private const val TAG = "MediaPipeEngineWithRagTag" } private var chainConfig: ChainConfig<String>? = null private var retrievalAndInferenceChain: RetrievalAndInferenceChain? = null private var engineMediaPipe: LlmInference? = null private var sessionOptions: LlmInferenceSessionOptions? = null private var mediaPipeLanguageModel: MediaPipeLlmBackend? = null private var interfaceOptions: LlmInferenceOptions? = null private val executor: Executor = MoreExecutors.directExecutor() private var future: ListenableFuture<LanguageModelResponse>? = null override fun startEngine( modelFile: File, isSupportImages: Boolean, engineParams: MediaPipeEngineParams, ) { logDebug("startEngine", TAG) interfaceOptions = createInterfaceOptions( modelFile = modelFile, engineParams = engineParams, isSupportImages = isSupportImages, ) engineMediaPipe = LlmInference.createFromOptions( application, interfaceOptions ) if (engineMediaPipe == null) { logError("startEngine llmInference == null", TAG) return } sessionOptions = createSessionOptions( engineParams = engineParams, isSupportImages = isSupportImages, ) mediaPipeLanguageModel = MediaPipeLlmBackend( application.applicationContext, interfaceOptions, sessionOptions, executor ) chainConfig = ChainConfig.create( mediaPipeLanguageModel, common.promptBuilder, // добавляем базу, в которой нужно проверять звпросы common.semanticMemory ) // делаем цепочку с проверкой в базе retrievalAndInferenceChain = RetrievalAndInferenceChain( chainConfig ) Futures.addCallback( mediaPipeLanguageModel!!.initialize(), object : FutureCallback<Boolean> { override fun onSuccess(result: Boolean) { logDebug("mediaPipeLanguageModel initialize onSuccess", TAG) } override fun onFailure(t: Throwable) { logError( "mediaPipeLanguageModel initialize onFailure: ${t.message}", t, TAG, ) } }, executor ) logDebug("startEngine ready", TAG) } override fun resetSession() { logDebug("resetSession", TAG) try { retrievalAndInferenceChain = RetrievalAndInferenceChain( chainConfig ) logDebug("Session reset completed", TAG) } catch (e: Exception) { logError("Failed to reset session: ${e.message}", e, TAG) } logDebug("resetSession ready", TAG) } override fun generateResponse( prompt: String, // Количество результатов запроса к базе topK: Int, // насколько вектор запроса query должен быть похожим на запись в базе // 0.0 = искать все записи, отсортировать по самым похожим // 1.0 = только идеальное совпадение // я использую значения 0.6 - 0.8 minSimilarityScore: Float, ): Flow<ResultEmittedData<String>> = callbackFlow { logDebug("generateResponse prompt: $prompt", TAG) try { if (retrievalAndInferenceChain == null) { logError("generateResponse retrievalAndInferenceChain == null", TAG) trySend( ResultEmittedData.error( model = null, error = null, title = "MediaPipe engine error", responseCode = null, message = "retrievalAndInferenceChain == null", errorType = ErrorType.ERROR_IN_LOGIC, ) ) return@callbackFlow } val retrievalConfig = RetrievalConfig.create( topK, minSimilarityScore, TaskType.QUESTION_ANSWERING ) // запрос уже включает цепочку с проверкой val retrievalRequest = RetrievalRequest.create( prompt, retrievalConfig ) logDebug("generateResponse retrievalRequest", TAG) val messageBuilder = StringBuilder() val listener = AsyncProgressListener<LanguageModelResponse> { partial, done -> val delta = partial.text.orEmpty() logDebug("generateResponse delta: $delta", TAG) if (!done && delta.isNotBlank()) { messageBuilder.append(delta) trySend( ResultEmittedData.loading( model = messageBuilder.toString(), ) ) } } future = retrievalAndInferenceChain!!.invoke( retrievalRequest, listener ) future?.addListener({ val fullText = future?.get()?.text if (fullText.isNullOrEmpty()) { logError("generateResponse fullText isNullOrEmpty", TAG) trySend( ResultEmittedData.error( model = null, error = null, title = "MediaPipe engine error", responseCode = null, message = "Empty response", errorType = ErrorType.EXCEPTION ) ) close() return@addListener } logDebug("generateResponse fullText: $fullText", TAG) trySend( ResultEmittedData.success( model = fullText, message = null, responseCode = null ) ) close() }, executor) logDebug("generateResponse ready", TAG) } catch (t: Throwable) { logError("generateResponse failed: ${t.message}", t, TAG) trySend( ResultEmittedData.error( model = null, error = t, title = "MediaPipe engine error", responseCode = null, message = t.message, errorType = ErrorType.EXCEPTION, ) ) } } private fun createInterfaceOptions( modelFile: File, engineParams: MediaPipeEngineParams, isSupportImages: Boolean, ): LlmInferenceOptions { val backend = when (engineParams.backend) { MediaPipeBackendParams.CPU -> LlmInference.Backend.CPU MediaPipeBackendParams.GPU -> LlmInference.Backend.GPU } return LlmInferenceOptions.builder().apply { setModelPath(modelFile.absolutePath) setMaxTokens(engineParams.contextSize) setPreferredBackend(backend) val maxNumImages = if (isSupportImages) 1 else 0 setMaxNumImages(maxNumImages) if (engineParams.useMaxTopK) setMaxTopK(engineParams.maxTopK) }.build() } private fun createSessionOptions( engineParams: MediaPipeEngineParams, isSupportImages: Boolean, ): LlmInferenceSessionOptions { return LlmInferenceSessionOptions.builder().apply { if (engineParams.useTopK) setTopK(engineParams.topK) if (engineParams.useTopP) setTopP(engineParams.topP) if (engineParams.useTemperature) setTemperature(engineParams.temperature) if (engineParams.useRandomSeed) setRandomSeed(engineParams.randomSeed) setGraphOptions( GraphOptions.builder() .setEnableVisionModality(isSupportImages) .build() ) }.build() } private fun isInGeneration(): Boolean { return future != null && future?.isDone != true && future?.isCancelled != true } }
MediaPipeEngineWithTools:
import kotlinx.coroutines.flow.Flow import java.io.File interface MediaPipeEngineWithTools { fun startEngine( modelFile: File, isSupportImages: Boolean = false, engineParams: MediaPipeEngineParams, ) fun generateResponse( userQuery: String, maxSteps: Int = 3, ): Flow<ResultEmittedData<String>> }
package com.romankryvolapov.offlineailauncher.mediapipe import android.app.Application import com.google.ai.edge.localagents.core.proto.Content import com.google.ai.edge.localagents.core.proto.FunctionCall import com.google.ai.edge.localagents.core.proto.FunctionDeclaration import com.google.ai.edge.localagents.core.proto.FunctionResponse import com.google.ai.edge.localagents.core.proto.GenerateContentResponse import com.google.ai.edge.localagents.core.proto.Part import com.google.ai.edge.localagents.core.proto.Schema import com.google.ai.edge.localagents.core.proto.Tool import com.google.ai.edge.localagents.fc.GemmaFormatter import com.google.ai.edge.localagents.fc.GenerativeModel import com.google.ai.edge.localagents.fc.LlmInferenceBackend import com.google.ai.edge.localagents.rag.memory.VectorStoreRecord import com.google.ai.edge.localagents.rag.models.EmbedData import com.google.ai.edge.localagents.rag.models.EmbeddingRequest import com.google.common.collect.ImmutableList import com.google.mediapipe.tasks.genai.llminference.LlmInference import com.google.mediapipe.tasks.genai.llminference.LlmInference.LlmInferenceOptions import com.google.protobuf.Struct import com.google.protobuf.Value import com.romankryvolapov.offlineailauncher.common.models.common.ErrorType import com.romankryvolapov.offlineailauncher.common.models.common.LogUtil.logDebug import com.romankryvolapov.offlineailauncher.common.models.common.LogUtil.logError import com.romankryvolapov.offlineailauncher.common.models.common.ResultEmittedData import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow import kotlinx.coroutines.guava.await import java.io.File class MediaPipeEngineWithToolsImpl( private val application: Application, private val common: MediaPipeEngineCommon, ) : MediaPipeEngineWithTools { companion object { private const val TAG = "MediaPipeEngineWithToolsTag" private const val DEFAULT_MIN_SIMILARITY_SCORE = 0.8 private const val TOOLS_CODE = "tool_code" private const val RESULTS = "results" private const val TOOLS_ACTION_SEARCH_DOCS = "search_docs" private const val TOOLS_ACTION_SEARCH_DOCS_DESCRIPTION = "Searches knowledge and returns the most relevant results as plain text." private const val TOOLS_PARAM_QUERY = "query" private const val TOOLS_PARAM_QUERY_DESCRIPTION = "User query to search in the vector store." private const val TOOLS_PARAM_TOP_K = "top_k" private const val TOOLS_PARAM_TOP_K_DESCRIPTION = "Number of results to return (default 5)." private const val TOOLS_PARAM_MIN_SIMILARITY_SCORE = "min_similarity_score" private const val MIN_SIMILARITY_SCORE_DESCRIPTION = """ Minimum similarity score threshold (float) for filtering search results, from 0.0 (no filtering) to 1.0 (exact match). Start with $DEFAULT_MIN_SIMILARITY_SCORE, and if no results are found, lower the value and retry the search. """ // от этого темплейта зависит отчено много // неправильно подобранные параметры сделают вызов инструмента не возможным // или после вызова инструмента генерация остановится // для других LLM моделей темплейт может отличаться // ```tool_code хорошо работает для Gemma 3n, похоже она была на этом ключевом слове обучена // если через инструменты ничего не найдено, в промпте указано, что сходство можно уменьшить // инструменты могут быть совершенно разными, например SQL запрос, запрос в интернет, важно их правильно описать private val PROMPT_TEMPLATE_WITH_TOOLS = """ You are an on-device assistant. You have access to special tools (also called: "function call", "invoke tool", "use API", "search", "lookup", "query tool") If you decide to invoke any of the function, it should be wrapped with ```$TOOLS_CODE``` You have access to the following tools. * `$TOOLS_ACTION_SEARCH_DOCS`: Searches knowledge and returns the most relevant results as plain text. WHEN TO USE A TOOL - If you do not have enough information to answer with high confidence. - If the user explicitly or implicitly asks to check/verify/find out/look up ("check via tools", "verify", "lookup", etc.). Tool args: $TOOLS_PARAM_QUERY: User string query to search in the vector store. $TOOLS_PARAM_TOP_K: Integer number of results to return (default 5). $TOOLS_PARAM_MIN_SIMILARITY_SCORE: Minimum similarity score threshold (float) for filtering search results, from 0.0 (no filtering) to 1.0 (exact match). Start with $DEFAULT_MIN_SIMILARITY_SCORE, and if no results are found, lower the value and retry the search. Rules for tool call: ```$TOOLS_CODE $TOOLS_ACTION_SEARCH_DOCS($TOOLS_PARAM_QUERY="<string>", $TOOLS_PARAM_TOP_K=<integer>, $TOOLS_PARAM_MIN_SIMILARITY_SCORE=<float>) ``` Tool response: $RESULTS: Plain text results. IMPORTANT: After receiving tool results, ALWAYS write a natural-language answer for the user in the very next message. If tool results are empty, briefly explain that nothing relevant was found and propose next steps. """.trimIndent() } private var generativeModel: GenerativeModel? = null override fun startEngine( modelFile: File, isSupportImages: Boolean, engineParams: MediaPipeEngineParams, ) { logDebug("startEngine", TAG) val interfaceOptions = createInterfaceOptions( modelFile = modelFile, engineParams = engineParams, isSupportImages = isSupportImages, ) val engineMediaPipe = LlmInference.createFromOptions( application, interfaceOptions ) if (engineMediaPipe == null) { logError("startEngine llmInference == null", TAG) return } val searchDocs = FunctionDeclaration.newBuilder() .setName(TOOLS_ACTION_SEARCH_DOCS) .setDescription(TOOLS_ACTION_SEARCH_DOCS_DESCRIPTION) .setParameters( Schema.newBuilder() .setType(com.google.ai.edge.localagents.core.proto.Type.OBJECT) .putProperties( TOOLS_PARAM_QUERY, Schema.newBuilder() .setType(com.google.ai.edge.localagents.core.proto.Type.STRING) .setDescription(TOOLS_PARAM_QUERY_DESCRIPTION) .build() ) .putProperties( TOOLS_PARAM_TOP_K, Schema.newBuilder() .setType(com.google.ai.edge.localagents.core.proto.Type.INTEGER) .setDescription(TOOLS_PARAM_TOP_K_DESCRIPTION) .build() ) .putProperties( TOOLS_PARAM_MIN_SIMILARITY_SCORE, Schema.newBuilder() .setType(com.google.ai.edge.localagents.core.proto.Type.NUMBER) .setDescription(MIN_SIMILARITY_SCORE_DESCRIPTION) .build() ) .build() ) .build() val systemInstruction = Content.newBuilder() .setRole(Gemma3nRoles.SYSTEM.type) .addParts( Part.newBuilder().setText( PROMPT_TEMPLATE_WITH_TOOLS ) ) .build() val tool = Tool.newBuilder() .addFunctionDeclarations(searchDocs) .build() val inferenceBackend = LlmInferenceBackend( engineMediaPipe, GemmaFormatter() ) generativeModel = GenerativeModel( inferenceBackend, systemInstruction, listOf(tool), ) logDebug("startEngine ready", TAG) } override fun generateResponse( userQuery: String, maxSteps: Int, ): Flow<ResultEmittedData<String>> = flow { logDebug("generateResponseWithTools userQuery: $userQuery", TAG) try { val generativeModel = generativeModel ?: run { logError("generateResponseWithTools generativeModel is null", TAG) emit( ResultEmittedData.error( model = null, error = null, title = "MediaPipe engine error", responseCode = null, message = "Model is not initialized;", errorType = ErrorType.ERROR_IN_LOGIC, ) ) return@flow } val contentPart = Part.newBuilder() .setText(userQuery) .build() val userContent = Content.newBuilder() .setRole(Gemma3nRoles.USER.type) .addParts(contentPart) .build() val conversation = mutableListOf(userContent) var step = 0 // на всякий случай здесь есть цикл // модели отправляется запрос, если она считает, что нужно вызвать инструмент, // она пишет служебную информацию, инструмент вызывается и запрос с результатом повторяется // если вам нужно, чтобы модель пыталась найти лучший результат запроса, // меняя текст запроса или минимальное сходство, напишите об этом в промпте, чтобы // модель значала, что вызовов инструментов может быть много while (step < maxSteps) { logDebug("generateResponseWithTools step: $step conversation: ${conversation.size}", TAG) step++ val response: GenerateContentResponse = generativeModel.generateContent( conversation ) val responseContent: Content = response.candidatesList.firstOrNull()?.content ?: run { logError("generateResponseWithTools content is null", TAG) emit( ResultEmittedData.error( model = null, error = null, title = "MediaPipe engine error", responseCode = null, message = "Candidates list is null", errorType = ErrorType.ERROR_IN_LOGIC, ) ) return@flow } val functionCall: FunctionCall? = responseContent.partsList.firstOrNull { it.hasFunctionCall() }?.functionCall // если модель посчитала, что инструменты вызывать не нужно- просто отправляем ответ прользователю if (functionCall == null) { val text = extractText(response) if (text.isBlank()) { logError( "generateResponseWithTools text is blank, response: $response", TAG ) emit( ResultEmittedData.error( model = null, error = null, title = "MediaPipe engine error", responseCode = null, message = "Empty text", errorType = ErrorType.ERROR_IN_LOGIC, ) ) return@flow } logDebug("generateResponseWithTools functionCall is null text: $text", TAG) emit( ResultEmittedData.success( model = text, message = null, responseCode = null ) ) return@flow } if (functionCall.name != TOOLS_ACTION_SEARCH_DOCS) { logError("generateResponseWithTools wrong name: ${functionCall.name}", TAG) val text = extractText(response) if (text.isBlank()) { logError( "generateResponseWithTools text is blank, response: $response", TAG ) emit( ResultEmittedData.error( model = null, error = null, title = "MediaPipe engine error", responseCode = null, message = "Wrong function call", errorType = ErrorType.ERROR_IN_LOGIC, ) ) return@flow } emit( ResultEmittedData.success( model = text, message = null, responseCode = null ) ) return@flow } val args = functionCall.args.fieldsMap // модель возвращает в параметрах вызова инструмента текст запроса к базе, // количество результатов и сходство // если ничего не найдено, в промпте указано, что сходство можно уменьшить val query = args[TOOLS_PARAM_QUERY]?.stringValue val topK = args[TOOLS_PARAM_TOP_K]?.numberValue?.toInt() ?: 5 val minSimilarityScore = args[TOOLS_PARAM_MIN_SIMILARITY_SCORE]?.numberValue?.toFloat() ?: 0.0F if (query.isNullOrEmpty()) { logError("generateResponseWithTools query is null or empty", TAG) val text = extractText(response) if (text.isBlank()) { logError( "generateResponseWithTools text is blank, response: $response", TAG ) emit( ResultEmittedData.error( model = null, error = null, title = "MediaPipe engine error", responseCode = null, message = "Wrong function call", errorType = ErrorType.ERROR_IN_LOGIC, ) ) return@flow } logDebug("generateResponseWithTools query is null or empty text: $text", TAG) emit( ResultEmittedData.success( model = text, message = null, responseCode = null ) ) return@flow } val results: String = searchDocsInternal( query, topK, minSimilarityScore ) val respStruct = Struct.newBuilder() .putFields( TOOLS_PARAM_QUERY, Value.newBuilder().setStringValue(query).build() ) .putFields( TOOLS_PARAM_TOP_K, Value.newBuilder().setNumberValue(topK.toDouble()).build() ) .putFields( TOOLS_PARAM_MIN_SIMILARITY_SCORE, Value.newBuilder().setNumberValue(minSimilarityScore.toDouble()).build() ) .putFields( RESULTS, Value.newBuilder().setStringValue(results).build() ) .build() val functionResponse = FunctionResponse.newBuilder() .setName(TOOLS_ACTION_SEARCH_DOCS) .setResponse(respStruct) .build() val functionResponsePart = Part.newBuilder() .setFunctionResponse(functionResponse) .build() val toolContent = Content.newBuilder() .setRole(Gemma3nRoles.MODEL.type) .addParts(functionResponsePart) .build() // добавляем ответ модели с вызовом инструмента и сам вызов инструмента // в цепочку сообщений и запускаем следующую итерацию цикла // модель таким образом будет видеть все свои запросы и все результаты вызова инструмента conversation.add(responseContent) conversation.add(toolContent) logDebug("conversation: $conversation", TAG) if (step == maxSteps) { val finalResponse = generativeModel.generateContent(conversation) val text = extractText(finalResponse) if (text.isBlank()) { logError("generateResponseWithTools finalResponse text is blank", TAG) emit( ResultEmittedData.error( title = "MediaPipe engine error", message = "Empty final response", error = null, model = null, responseCode = null, errorType = ErrorType.ERROR_IN_LOGIC, ) ) return@flow } emit( ResultEmittedData.success( model = text, message = null, responseCode = null ) ) return@flow } } } catch (t: Throwable) { logError("generateResponseWithTools failed: ${t.message}", t, TAG) emit( ResultEmittedData.error( model = null, error = t, title = "MediaPipe engine error", responseCode = null, message = t.message, errorType = ErrorType.EXCEPTION, ) ) } } // поиск в векторной базе, при этом все параметры задает сама модель private suspend fun searchDocsInternal( query: String, topK: Int, minSimilarityScore: Float, ): String { logDebug("searchDocsInternal query: $query topK: $topK minSimilarityScore: $minSimilarityScore", TAG) val queryEmbedData: EmbedData<String> = EmbedData.create( query, EmbedData.TaskType.RETRIEVAL_QUERY ) val embeddingRequest: EmbeddingRequest<String> = EmbeddingRequest.create(listOf(queryEmbedData)) val vector: ImmutableList<Float> = try { common.embedder.getEmbeddings(embeddingRequest).await() } catch (t: Throwable) { logError( "searchDocsInternal: embedding failed: ${t.message}", t, TAG ) return "No results." } if (vector.isEmpty()) { logError("searchDocsInternal vector.isEmpty()", TAG) return "No results." } val hits: ImmutableList<VectorStoreRecord<String>> = try { common.vectorStore.getNearestRecords( vector, topK, minSimilarityScore ) } catch (t: Throwable) { logError("searchDocsInternal: failed: ${t.message}", t, TAG) return "No results." } if (hits.isEmpty()) { logError("searchDocsInternal hits.isEmpty()", TAG) return "No results." } val result = buildString { for (h in hits) { appendLine(h.data.trim()) } }.trim() logDebug("searchDocsInternal ready size: ${result.length}", TAG) return result } private fun extractText(response: GenerateContentResponse): String { response.candidatesList.forEach { candidate -> candidate.content.partsList.forEach { part -> if (part.text.isNotEmpty()) return part.text } } return "" } private fun createInterfaceOptions( modelFile: File, engineParams: MediaPipeEngineParams, isSupportImages: Boolean, ): LlmInferenceOptions { val backend = when (engineParams.backend) { MediaPipeBackendParams.CPU -> LlmInference.Backend.CPU MediaPipeBackendParams.GPU -> LlmInference.Backend.GPU } return LlmInferenceOptions.builder().apply { setModelPath(modelFile.absolutePath) setMaxTokens(engineParams.contextSize) setPreferredBackend(backend) val maxNumImages = if (isSupportImages) 1 else 0 setMaxNumImages(maxNumImages) if (engineParams.useMaxTopK) setMaxTopK(engineParams.maxTopK) }.build() } }
Надеюсь было интересно.
Кто захочет повторить, использованные вспомогательные классы:
enum class ErrorType { EXCEPTION, SERVER_ERROR, ERROR_IN_LOGIC, SERVER_DATA_ERROR, NO_INTERNET_CONNECTION, AUTHORIZATION } data class ResultEmittedData<out T>( val model: T?, val error: Any?, val status: Status, val title: String?, val message: String?, val responseCode: Int?, val errorType: ErrorType?, ) { enum class Status { SUCCESS, ERROR, LOADING, } companion object { fun <T> success( model: T, message: String?, responseCode: Int?, ): ResultEmittedData<T> = ResultEmittedData( error = null, title = null, model = model, errorType = null, message = message, status = Status.SUCCESS, responseCode = responseCode, ) fun <T> loading( model: T? = null, message: String? = null, ): ResultEmittedData<T> = ResultEmittedData( model = model, error = null, title = null, errorType = null, message = message, responseCode = null, status = Status.LOADING, ) fun <T> error( model: T?, error: Any?, title: String?, message: String?, responseCode: Int?, errorType: ErrorType?, ): ResultEmittedData<T> = ResultEmittedData( model = model, error = error, title = title, message = message, errorType = errorType, status = Status.ERROR, responseCode = responseCode, ) } } inline fun <T : Any> ResultEmittedData<T>.onLoading( action: ( model: T?, message: String?, ) -> Unit ): ResultEmittedData<T> { if (status == ResultEmittedData.Status.LOADING) action( model, message ) return this } inline fun <T : Any> ResultEmittedData<T>.onSuccess( action: ( model: T, message: String?, responseCode: Int?, ) -> Unit ): ResultEmittedData<T> { if (status == ResultEmittedData.Status.SUCCESS && model != null) action( model, message, responseCode, ) return this } inline fun <T : Any> ResultEmittedData<T>.onFailure( action: ( model: Any?, title: String?, message: String?, responseCode: Int?, errorType: ErrorType?, ) -> Unit ): ResultEmittedData<T> { if (status == ResultEmittedData.Status.ERROR) action( model, title, message, responseCode, errorType ) return this }
data class MediaPipeEngineParams( val name: String, val topK: Int, val topP: Float, val temperature: Float, val randomSeed: Int, val contextSize: Int, val maxTopK: Int, val useTopK: Boolean, val useTopP: Boolean, val useTemperature: Boolean, val useRandomSeed: Boolean, val useMaxTopK: Boolean, val backend: MediaPipeBackendParams, ) enum class MediaPipeBackendParams { CPU, GPU } fun Long.toDurationString(): String { var msRemaining = this val years = msRemaining / (365L * 24 * 60 * 60 * 1000) msRemaining %= (365L * 24 * 60 * 60 * 1000) val months = msRemaining / (30L * 24 * 60 * 60 * 1000) msRemaining %= (30L * 24 * 60 * 60 * 1000) val days = msRemaining / (24L * 60 * 60 * 1000) msRemaining %= (24L * 60 * 60 * 1000) val hours = msRemaining / (60L * 60 * 1000) msRemaining %= (60L * 60 * 1000) val minutes = msRemaining / (60L * 1000) msRemaining %= (60L * 1000) val seconds = msRemaining / 1000 val milliseconds = msRemaining % 1000 return buildString { if (years > 0) append("$years years, ") if (months > 0) append("$months months, ") if (days > 0) append("$days days, ") if (hours > 0) append("$hours hours, ") if (minutes > 0) append("$minutes minutes, ") if (seconds > 0) append("$seconds seconds, ") append("$milliseconds milliseconds") } }
import android.annotation.SuppressLint import android.os.Bundle import android.os.Environment import android.util.Log import com.google.firebase.analytics.FirebaseAnalytics import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.launch import org.koin.core.component.KoinComponent import org.koin.core.component.inject import java.io.File import java.io.FileOutputStream import java.text.SimpleDateFormat import java.util.Date object LogUtil : KoinComponent { private val timeDirectoryName: String private const val QUEUE_CAPACITY = 10000 private const val CURRENT_TAG = "LogUtilExecutionStatusTag" private const val LOG_APP_FOLDER_NAME = "app" private const val TIME_FORMAT_FOR_LOG = "HH:mm:ss dd-MM-yyyy" private const val TIME_FORMAT_FOR_DIRECTORY = "HH-mm-ss_dd-MM-yyyy" private const val TAG = "TAG: " private const val TIME = "TIME: " private const val ERROR_STACKTRACE = "ERROR STACKTRACE: " private const val ERROR_MESSAGE = "ERROR: " private const val DEBUG_MESSAGE = "MESSAGE: " private const val NEW_LINE = "\n" private val queue = ArrayDeque<LogData>(QUEUE_CAPACITY) private var saveLogsToTxtFileJob: Job? = null private val analytics: FirebaseAnalytics by inject() @Volatile private var isSaveLogsToTxtFile = false init { Log.d(CURRENT_TAG, "init") timeDirectoryName = getCurrentTimeForDirectory() } fun logDebug(message: String, tag: String) { CoroutineScope(Dispatchers.IO).launch { if (BuildConfig.DEBUG) { Log.d(tag, message) enqueue( LogData.DebugMessage( tag = tag, time = System.currentTimeMillis(), message = message, ) ) saveLogsToTxtFile() } } } fun logError(message: String, tag: String) { CoroutineScope(Dispatchers.IO).launch { if (BuildConfig.DEBUG) { Log.e(tag, message) enqueue( LogData.ErrorMessage( tag = tag, time = System.currentTimeMillis(), message = message, ) ) saveLogsToTxtFile() } } } fun logError(exception: Throwable, tag: String) { CoroutineScope(Dispatchers.IO).launch { if (BuildConfig.DEBUG) { Log.e(tag, exception.message, exception) enqueue( LogData.ExceptionMessage( tag = tag, time = System.currentTimeMillis(), exception = exception, ) ) saveLogsToTxtFile() } } } fun logError(message: String, exception: Throwable, tag: String) { CoroutineScope(Dispatchers.IO).launch { if (BuildConfig.DEBUG) { Log.e(tag, "$message, exception: ${exception.message}", exception) enqueue( LogData.ErrorMessageWithException( tag = tag, time = System.currentTimeMillis(), message = message, exception = exception, ) ) saveLogsToTxtFile() } } } fun logError(message: String, error: String?, tag: String) { CoroutineScope(Dispatchers.IO).launch { if (BuildConfig.DEBUG) { Log.e(tag, "$message, error: $error") enqueue( LogData.ErrorMessage( tag = tag, time = System.currentTimeMillis(), message = message, ) ) saveLogsToTxtFile() } } } @SuppressLint("SimpleDateFormat") private fun getTime(time: Long): String { return try { val date = Date(time) val timeString = SimpleDateFormat(TIME_FORMAT_FOR_LOG).format(date) timeString.ifEmpty { Log.e(CURRENT_TAG, "getTime time.ifEmpty") time.toString() } } catch (e: Exception) { Log.e(CURRENT_TAG, "getCurrentTime exception: ${e.message}", e) time.toString() } } @SuppressLint("SimpleDateFormat") private fun getCurrentTimeForDirectory(): String { val time = System.currentTimeMillis() return try { val date = Date(time) val timeString = SimpleDateFormat(TIME_FORMAT_FOR_DIRECTORY).format(date) Log.d(CURRENT_TAG, "getCurrentTimeForDirectory time: $time") timeString.ifEmpty { Log.e(CURRENT_TAG, "getCurrentTimeForDirectory time.ifEmpty") time.toString() } } catch (e: Exception) { Log.e(CURRENT_TAG, "getCurrentTimeForDirectory exception: ${e.message}", e) time.toString() } } private fun enqueue(message: LogData) { try { while (queue.size >= QUEUE_CAPACITY) { Log.d(CURRENT_TAG, "enqueue removeFirst") queue.removeFirst() } queue.addLast(message) } catch (e: Exception) { Log.e(CURRENT_TAG, "enqueue exception: ${e.message}", e) } } }