個人的なメモ

Tomohiro Suzuki @hiro128_777 のブログです。Microsoft MVP for Developer Technologies 2017- 本ブログと所属組織の公式見解は関係ございません。

C# を使って CPU で Edge AI を試してみよう(3)日本語で精度を高めるために、英語に翻訳してプロンプトを投げる

 
前の記事
hiro128.hatenablog.jp
 

はじめに

前回、ローカルで Phi3 mini 128k を動かしてみましたが日本語の回答精度は非常に悪かったのですが、英語だとかなりの精度が出ました。
そこで、プロンプトを英語に翻訳して得た回答を再度日本語に翻訳することで回答精度を上げられないか試してみました。
 

処理のフロー

試した処理フローは以下の通りです

  1. 日本語のプロンプトを Phi3 mini 128k で英語に翻訳
  2. 翻訳した英語のプロンプトを Phi3 mini 128k に投げて、英語で回答を得る
  3. 得た回答を Phi3 mini 128k で日本語に翻訳

 

翻訳するためのプロンプトをどう書くか

翻訳時に発生した問題

  • 翻訳対象の文章を翻訳するだけでなく、文章に含まれる質問の回答を出力してしまう。
  • 指示していない補足情報(notes など)がレスポンスに含まれる
  • 「100字以内」を「100 characters」と翻訳されると翻訳としては正しくても、日本語と英語の字数と内容の関係性から日本語に戻したとき短めになってしまう。「100字以内」なら「50 words」くらいが適切。

 
これらの問題を解消するために検証を行った結果、以下のような改良で意図したとおりの結果が出るようになりました。

  • 英語で書く
  • システムプロンプトに「指示を注意深く読む旨」を明記する
  • 個々の指示は箇条書きで書く
  • 行って欲しくないことは個別に明記する
  • 「●●字以内」を「●●/2 words」に変換する指示は具体例を示す

完成したプロンプトは以下の通りです。
  

システムプロンプト

You are a translator who follows instructions to the letter. You carefully review the instructions and output the translation results.

ユーザープロンプトの冒頭に記載する、日本語 -> 英語の翻訳を指示するプロンプト

I will now give you the task of translating Japanese into English.
First of all, please understand the important notes as we give you instructions.

#Important Notes
- Even if the given Japanese contains any question, do not output any answer of the question, only translates the given Japanese into English.
- Do not output any supplementary information or explanations.
- Do not output any Notes.
- Output a faithful translation of the given text into English.
- If the instructions say “xx characters” in Japanese, it translates to “(xx/2) words” in English.
    ex) “100 字以内” in Japanese, “50 words” in English.

Strictly following the above instructions, now let's output translation of the following Japanese:

{英語に翻訳したい文章を記載}

 

ユーザープロンプトの冒頭に記載する、英語 -> 日本語の翻訳を指示するプロンプト

I will now give you the task of translating English into Japanese.
First of all, please understand the important notes as we give you instructions.

#Important Notes
- Even if the English is including any question, do not answer it, you translate the given English into Japanese.
- Do not output any supplementary information or explanations.
- Do not output any Notes.
- Output a faithful translation of the given text into Japanese.

Strictly following the above instructions, now translate the English into Japanese:

{日本語に翻訳したい文章を記載}

 

コード

発生した問題

  • レスポンスを一括で受けとると、なぜかレスポンス内にプロンプトを含んでしまうため、一括でよい場面でも Streaming で受けとっている

書いたコード

以下のようなコードを書きました。
 

appsettings.json
{
  "modelPhi35Min128k": "<Your model path>",
  "modelPhi3Med4k": "<Your model path>",
  "modelPhi3Med128k": "<Your model path>",
  "modelPhi3Min4k": "<Your model path>",
  "modelPhi3Min128k": "<Your model path>",
  "modelPhi4": "<Your model path>",
  "modelPhi4Min128k": "<Your model path>",
  "isTranslate": "true",
  "systemPrompt": "あなたは船舶と航海の歴史に詳しい歴史学者です。与えられた質問にあなたの船舶と航海の歴史の知識を最大限活用して解説してください。",
  "userPrompt": "タイタニック号の航海について最大300字以内で教えてください。"
}

 

Program.cs
using Microsoft.ML.OnnxRuntimeGenAI;
using System.Diagnostics;
using System.Text;
using Build5Nines.SharpVector;
using Build5Nines.SharpVector.Data;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Configuration;

var newLine = Environment.NewLine;

var builder = Host.CreateApplicationBuilder(args);
builder.Configuration.Sources.Clear();
builder.Configuration
    .AddJsonFile("appsettings.json", optional: true, reloadOnChange: true)
    .AddJsonFile($"appsettings.{builder.Environment.EnvironmentName}.json", optional: true, reloadOnChange: true)
    .Build();

var configuration = builder.Configuration;

var modelPath = new ModelPath(builder);
var prompt = new Prompt(builder);
var option = new Option(builder);

string additionalDocumentsPath = configuration["additionalDocumentsPath"] ?? throw new ArgumentNullException("additionalDocumentsPath is not found");

using OgaHandle ogaHandle = new OgaHandle();

// RAG 用のベクトルデータベースのセットアップ
var additionalDocumentsDirectory = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, additionalDocumentsPath);
var vectorDatabase = new BasicMemoryVectorDatabase();
LoadAdditionalDocuments(additionalDocumentsDirectory).Wait();
Console.WriteLine();

// モデルのセットアップ
Console.WriteLine($"Loading model:{newLine}{modelPath.Phi3Min128k}");

var sw = Stopwatch.StartNew();
using Model model = new Model(modelPath.Phi3Min128k);
using Tokenizer tokenizer = new Tokenizer(model);
sw.Stop();
 
Console.WriteLine($"{newLine}Model loading time is {sw.Elapsed.Seconds:0.00} sec.\n");

// 翻訳するかどうか
Console.WriteLine($"翻訳する:{newLine}{option.IsTranslate}");

// プロンプトのセットアップ
Console.WriteLine($"{newLine}システムプロンプト:{newLine}{prompt.System}");
Console.WriteLine($"{newLine}ユーザープロンプト:{newLine}{prompt.User}{newLine}");

var translatedSystemPrompt = string.Empty;
if (option.IsTranslate)
{
    Console.WriteLine("Translated System Prompt:");
    await foreach (var translatedPart in Translate(prompt.System, Language.Japanese, Language.English))
    {
        Console.Write(translatedPart);
        translatedSystemPrompt += translatedPart;
    }
    Console.WriteLine($"{newLine}----------------------------------------{newLine}");
}
else
{
    translatedSystemPrompt = prompt.System;
}

var translatedUserPrompt = string.Empty;
if (option.IsTranslate)
{
    Console.WriteLine("Translated User Prompt:");
    await foreach (var translatedPart in Translate(prompt.User, Language.Japanese, Language.English))
    {
        Console.Write(translatedPart);
        translatedUserPrompt += translatedPart;
    }
    Console.WriteLine($"{newLine}----------------------------------------{newLine}");
}
else
{
    translatedUserPrompt = prompt.User;
}

Console.WriteLine($"{newLine}システムプロンプト:{newLine}{translatedSystemPrompt}");
Console.WriteLine($"{newLine}ユーザープロンプト:{newLine}{translatedUserPrompt}{newLine}");

var sequences = tokenizer.Encode($@"<|system|>{translatedSystemPrompt}<|end|><|user|>{translatedUserPrompt}<|end|><|assistant|>");

// プロンプトを投げて回答を得る
using GeneratorParams generatorParams = new GeneratorParams(model);
generatorParams.SetSearchOption("min_length", 100);
generatorParams.SetSearchOption("max_length", 2000);
generatorParams.SetSearchOption("past_present_share_buffer", false);
generatorParams.TryGraphCaptureWithMaxBatchSize(1);
generatorParams.SetInputSequences(sequences);

using var tokenizerStream = tokenizer.CreateStream();
using var generator = new Generator(model, generatorParams);
StringBuilder stringBuilder = new();

Console.WriteLine("Response:");

var totalTokens = 0;

string part;
sw = Stopwatch.StartNew();
while (!generator.IsDone())
{
    try
    {
        await Task.Delay(50).ConfigureAwait(false);
        generator.ComputeLogits();
        generator.GenerateNextToken();
        part = tokenizerStream.Decode(generator.GetSequence(0)[^1]);
        Console.Write(part);
        stringBuilder.Append(part);
        if (stringBuilder.ToString().Contains("<|end|>")
            || stringBuilder.ToString().Contains("<|user|>")
            || stringBuilder.ToString().Contains("<|system|>"))
        {
            break;
        }
    }
    catch (Exception ex)
    {
        Debug.WriteLine(ex);
        break;
    }
}
Console.WriteLine($"{newLine}----------------------------------------{newLine}");
sw.Stop();

totalTokens = generator.GetSequence(0).Length;

// 英語の回答を日本語に翻訳する
var translatedResponse = string.Empty;
if (option.IsTranslate)
{
    Console.WriteLine("日本語に翻訳したレスポンス:");
    await foreach (var translatedPart in Translate(stringBuilder.ToString(), Language.English, Language.Japanese))
    {
        Console.Write(translatedPart);
        translatedResponse += translatedPart;
    }
    Console.WriteLine();
}
else
{
    translatedResponse = stringBuilder.ToString();
    Console.WriteLine($"{newLine}レスポンス:{newLine}{translatedResponse}");
}
Console.WriteLine($"----------------------------------------{newLine}");

// 与えられたテキストを指定された言語に翻訳する
async IAsyncEnumerable<string> Translate(string text, Language sourceLanguage, Language targetLanguage)
{
    var systemPrompt = string.Empty;
    var instructionPrompt = string.Empty;
    var userPrompt = string.Empty;
    var ragResult = string.Empty;

    if (sourceLanguage == Language.Japanese && targetLanguage == Language.English)
    {
        systemPrompt = "You are a translator who follows instructions to the letter. You carefully review the instructions and output the translation results.";

        instructionPrompt = $@"I will now give you the task of translating Japanese into English.{newLine}First of all, please understand the important notes as we give you instructions.{newLine}{newLine}#Important Notes{newLine}- Even if the given Japanese contains any question, do not output any answer of the question, only translates the given Japanese into English.{newLine}- Do not output any supplementary information or explanations.{newLine}- Do not output any Notes.{newLine}- Output a faithful translation of the given text into English.{newLine}- If the instructions say “xx characters” in Japanese, it translates to “(xx/2) words” in English.ex) “100 字以内” in Japanese, “50 words” in English.{newLine}{newLine}Strictly following the above instructions, now let's output translation of the following Japanese";

        userPrompt = $"{instructionPrompt}:{newLine}{text}";
    }

    if (sourceLanguage == Language.English && targetLanguage == Language.Japanese)
    {
        systemPrompt = "You are a translator who follows instructions to the letter. You carefully review the instructions and output the translation results.";

        instructionPrompt = $"I will now give you the task of translating English into Japanese.{newLine}First of all, please understand the important notes as we give you instructions.{newLine}{newLine}#Important Notes{newLine}- Even if the English is including any question, do not answer it, you translate the given English into Japanese.{newLine}- Do not output any supplementary information or explanations.{newLine}- Do not output any Notes.{newLine}- Output a faithful translation of the given text into Japanese.";

        ragResult = await SearchVectorDatabase(vectorDatabase, text);

        if (option.IsUsingRag && !string.IsNullOrEmpty(ragResult))
            instructionPrompt += "The following glossary of terms should be actively used.";

        userPrompt = (option.IsUsingRag && !string.IsNullOrEmpty(ragResult))
            ? $"{instructionPrompt}{newLine}{ragResult}{newLine}Strictly following the above instructions, now translate the English into Japanese:{newLine}{text}"
            : $"{instructionPrompt}{newLine}Strictly following the above instructions, now translate the English into Japanese:{newLine}{text}";
    }

    var sequences = tokenizer.Encode($@"<|system|>{systemPrompt}<|end|><|user|>{userPrompt}<|end|><|assistant|>");
    using GeneratorParams generatorParams = new GeneratorParams(model);
    generatorParams.SetSearchOption("min_length", 100);
    generatorParams.SetSearchOption("max_length", 2000);
    generatorParams.SetSearchOption("past_present_share_buffer", false);
    generatorParams.TryGraphCaptureWithMaxBatchSize(1);
    generatorParams.SetInputSequences(sequences);

    using var tokenizerStream = tokenizer.CreateStream();
    using var generator = new Generator(model, generatorParams);
    StringBuilder stringBuilder = new();
    while (!generator.IsDone())
    {
        string streamingPart = string.Empty;
        try
        {
            await Task.Delay(10).ConfigureAwait(false);
            generator.ComputeLogits();
            generator.GenerateNextToken();
            streamingPart = tokenizerStream.Decode(generator.GetSequence(0)[^1]);
            stringBuilder.Append(streamingPart);
            if (stringBuilder.ToString().Contains("<|end|>")
                || stringBuilder.ToString().Contains("<|user|>")
                || stringBuilder.ToString().Contains("<|system|>"))
            {
                break;
            }
        }
        catch (Exception ex)
        {
            Debug.WriteLine(ex);
            break;
        }
        yield return streamingPart;
    }
}

async Task LoadAdditionalDocuments(string directoryPath)
{
    Console.WriteLine($"Loading Additional Documents:");
    var files = Directory.GetFiles(directoryPath, "*.*", SearchOption.AllDirectories)
                             .Where(f => f.EndsWith(".txt", StringComparison.OrdinalIgnoreCase) ||
                                         f.EndsWith(".md", StringComparison.OrdinalIgnoreCase) ||
                                         f.EndsWith(".mdx", StringComparison.OrdinalIgnoreCase)).ToArray();

    var vectorDataLoader = new TextDataLoader<int, string>(vectorDatabase);
    var tasks = files.Select(async file =>
    {
        Console.WriteLine($"{file}");
        if (System.IO.File.Exists(file))
        {
            var fileContents = await System.IO.File.ReadAllTextAsync(file);
            await vectorDataLoader.AddDocumentAsync(fileContents, new TextChunkingOptions<string>
            {
                Method = TextChunkingMethod.Paragraph,
                RetrieveMetadata = (chunk) => file
            });
        }
    });
    await Task.WhenAll(tasks);
}

async Task<string> SearchVectorDatabase(BasicMemoryVectorDatabase vectorDatabase, string userPrompt)
{
    var vectorDataResults = await vectorDatabase.SearchAsync(
        userPrompt,
        pageCount: 3,
        threshold: 0.3f
    );

    string result = string.Empty;
    foreach (var resultItem in vectorDataResults.Texts)
    {
        result += $"{resultItem.Text}{newLine}";
    }

    return result;
}

public sealed class ModelPath
{
    private readonly string modelPhi35Min128k;
    private readonly string modelPhi3Med4k;
    private readonly string modelPhi3Med128k;
    private readonly string modelPhi3Min4k;
    private readonly string modelPhi3Min128k;
    private readonly string modelPhi4;
    private readonly string modelPhi4Min128k;

    public ModelPath(HostApplicationBuilder builder)
    {
        var configuration = builder.Configuration;

        modelPhi35Min128k = configuration["modelPhi35Min128k"] ?? throw new ArgumentNullException("modelPhi35Min128k is not found.");
        modelPhi3Med4k = configuration["modelPhi3Med4k"] ?? throw new ArgumentNullException("modelPhi3Med4k is not found.");
        modelPhi3Med128k = configuration["modelPhi3Med128k"] ?? throw new ArgumentNullException("modelPhi3Med128k is not found.");
        modelPhi3Min4k = configuration["modelPhi3Min4k"] ?? throw new ArgumentNullException("modelPhi3Min4k is not found.");
        modelPhi3Min128k = configuration["modelPhi3Min128k"] ?? throw new ArgumentNullException("modelPhi3Min128k is not found.");
        modelPhi4 = configuration["modelPhi4"] ?? throw new ArgumentNullException("modelPhi4 is not found.");
        modelPhi4Min128k = configuration["modelPhi4Min128k"] ?? throw new ArgumentNullException("modelPhi4Min128k is not found.");
    }

    public string Phi35Min128k { get => modelPhi35Min128k; }
    public string Phi3Med4k { get => modelPhi3Med4k; }
    public string Phi3Med128k { get => modelPhi3Med128k; }
    public string Phi3Min4k { get => modelPhi3Min4k; }
    public string Phi3Min128k { get => modelPhi3Min128k; }
    public string Phi4 { get => modelPhi4; }
    public string Phi4Min128k { get => modelPhi4Min128k; }
}

public sealed class Prompt
{
    private readonly string systemPrompt;
    private readonly string userPrompt;

    public Prompt(HostApplicationBuilder builder)
    {
        var configuration = builder.Configuration;

        systemPrompt = configuration["systemPrompt"] ?? throw new ArgumentNullException("systemPrompt is not found.");
        userPrompt = configuration["userPrompt"] ?? throw new ArgumentNullException("userPrompt is not found.");
    }

    public string System { get => systemPrompt; }
    public string User { get => userPrompt; }
}

public sealed class Option
{
    private readonly bool isTranslate;
    private readonly bool isUsingRag;

    public Option(HostApplicationBuilder builder)
    {
        var configuration = builder.Configuration;
        isTranslate = bool.TryParse(configuration["isTranslate"] ?? throw new ArgumentNullException("isTranslate is not found."), out var resultIsTranslate) && resultIsTranslate;
        isUsingRag = bool.TryParse(configuration["isUsingRag"] ?? throw new ArgumentNullException("isUsingRag is not found."), out var resultIsUsingRag) && resultIsUsingRag;
    }

    public bool IsTranslate { get => isTranslate; }
    public bool IsUsingRag { get => isUsingRag; }
}

public enum Language
{
    Japanese,
    English
}

 

結果

  • 英語でプロンプトを投げるため日本語で問い合わせるより英語の時点での回答精度はかなり高い
  • しかし、英語で受けとった回答に翻訳を通すためどうしても固有名詞や言い回しがうまく日本語に翻訳できず最終的な結果は微妙になってしまう
システムプロンプト
あなたは船舶と航海の歴史に詳しい歴史学者です。与えられた質問にあなたの船舶と航海の歴史の知識を最大限活用して解説してください。

 

ユーザープロンプト
タイタニック号の航海について最大300字以内で教えてください。

 

レスポンス(英語)
 The Titanic, a British luxury liner, embarked on its maiden voyage in 1912, aiming to connect America and Europe. Built by the White Star Line, it was deemed "unsinkable" due to its advanced safety features. However, on April 14th, it struck an iceberg in the North Atlantic. Despite distress signals, the Titanic sank in just over two hours, leading to the deaths of more than 1,500 passengers and crew. Survivors were rescued by the nearby RMS Carpathia, sparking global mourning and leading to significant changes in maritime safety regulations, including the establishment of the International Convention for the Safety of Life at Sea (SOLAS) in 1914. The Titanic's tragic voyage remains a poignant symbol of human vulnerability and the importance of safety in seafaring.

 

レスポンス(日本語)
 ティタニック、英国の優雅な船であり、1912年に初航行を計画していました。ホワイトオーストライト・ラインによって作られ、"unsinkable"と評され、請求されていました。しかし、1912年4月14 日、北アトランティスのアイスヌートに倒しました。警報が出されたものの、2時間以上、妬りました。その後、ホワイトオーストライト・ラインのRMSカルパディアによって致命的な救助が行われ、 世界的な哀悼が始まり、航海の安全規約に関する規定の変化が起こり、1914年に設立された国際生命を守る船の安全規約(SOLAS)の設立につながりました。ティタニックの悲劇的な航路は、人間の脆弱性と航海における安全の重要性を象徴するものです。

翻訳が微妙な箇所

  • ティタニック -> タイタニック
  • "unsinkable" -> 「不沈船」
  • ホワイトオーストライト・ライン -> ホワイト・スター・ライン
  • RMSカルパディアによって致命的な救助が行われ -> 生存者は近くにいたRMSカルパチアによって救助された

 

結論

英語の回答は正確であり、日本語の翻訳精度に関しては主に固有名詞や微妙な言い回しの対応がうまくいっていない状況。さらに RAG を導入し固有名詞や微妙な言い回しを正しく翻訳できるようにして、精度を高められるかどうか検証を続けてみます。
 
次の記事
hiro128.hatenablog.jp