var results = session.run(inputMap) fails with error: Input Missing Input: attention_mask #1

@5-49

Running the example code produces the following error in the console:

2024-05-07 13:22:48.6268951 [E:onnxruntime:, sequential_executor.cc:514 onnxruntime::ExecuteKernel] Non-zero status code returned while running Unsqueeze node. Name:'/bert/Unsqueeze' Status Message: C:\a_work\1\s\include\onnxruntime\core/framework/op_kernel_context.h:42 onnxruntime::OpKernelContext::Input Missing Input: attention_mask

Exception in thread "main" ai.onnxruntime.OrtException: Error code - ORT_RUNTIME_EXCEPTION - message: Non-zero status code returned while running Unsqueeze node. Name:'/bert/Unsqueeze' Status Message: C:\a_work\1\s\include\onnxruntime\core/framework/op_kernel_context.h:42 onnxruntime::OpKernelContext::Input Missing Input: attention_mask
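The message most likely means that the `inputMap` passed to `session.run` contains no tensor named `attention_mask`, even though the exported graph (here its `/bert/Unsqueeze` node) reads one. BERT exports typically declare `input_ids`, `attention_mask` and `token_type_ids`, and `session.getInputNames()` lists the exact names this particular model expects. Below is a minimal sketch of deriving the auxiliary arrays from a padded `long[][]` of input ids. It assumes each row is padded with the `[PAD]` id (0 in the standard bert-base-uncased `vocab.txt`) and holds a single sentence; `BertAuxInputs` is a hypothetical helper, not part of ONNX Runtime or of the `BertTokenizer` used in this issue.

```java
/**
 * Hypothetical helper (not part of ONNX Runtime or the BertTokenizer in this issue):
 * derives the auxiliary BERT inputs from a padded batch of token ids.
 * Assumes padding positions hold padId and each row contains one sentence.
 */
final class BertAuxInputs {
    private BertAuxInputs() {}

    /** Returns 1 for real tokens and 0 for padding positions, shape [batch, seqLen]. */
    static long[][] attentionMask(long[][] inputIds, long padId) {
        long[][] mask = new long[inputIds.length][];
        for (int b = 0; b < inputIds.length; b++) {
            mask[b] = new long[inputIds[b].length];
            for (int t = 0; t < inputIds[b].length; t++) {
                mask[b][t] = inputIds[b][t] == padId ? 0L : 1L;
            }
        }
        return mask;
    }

    /** Returns all zeros: with one sentence per row, every token belongs to segment A. */
    static long[][] tokenTypeIds(long[][] inputIds) {
        long[][] types = new long[inputIds.length][];
        for (int b = 0; b < inputIds.length; b++) {
            // Java zero-initializes new arrays, so no explicit fill is needed.
            types[b] = new long[inputIds[b].length];
        }
        return types;
    }
}
```

Each array can then be wrapped with `OnnxTensor.createTensor(env, mask)` and added to the input map under the key `attention_mask` (and `token_type_ids`, if the model lists that input).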

What might be causing this?

Here is my complete code:

package org.jadestudio;

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;

import java.util.Arrays;

//"C:\\Users\\eess6\\Desktop\\rico-process\\bert-base-uncased\\model.onnx"

public class OnnxTests {
    public static void main(String[] args) throws Exception {
        // Load the ONNX model
        String modelPath = "C:\\Users\\eess6\\Desktop\\rico-process\\bert-base-uncased\\model.onnx";
        String vocabPath = "C:\\Users\\eess6\\Desktop\\rico-process\\bert-base-uncased\\vocab.txt";
        // Initialize the environment and session
        var env = OrtEnvironment.getEnvironment();
        var session = env.createSession(modelPath, new OrtSession.SessionOptions());

        // Create the tokenizer (BertTokenizer is a class in this package; its source is not shown)
        BertTokenizer bertTokenizer = new BertTokenizer(vocabPath);

        // Build the ONNX input tensors
        var inputMap = bertTokenizer.tokenizeOnnxTensor(Arrays.asList("I like apple", "Apple is good"));

        try {
            // Run the model
            try (var results = session.run(inputMap)) {
                var embeddings = (float[][]) results.get(0).getValue();

                // Extract the embedding vectors
                float[] embedding1 = embeddings[0];
                float[] embedding2 = embeddings[1];

                // Compute the cosine similarity
                double cosineSimilarity = cosineSimilarity(embedding1, embedding2);
                System.out.println("Cosine Similarity: " + cosineSimilarity);
            }
        } finally {
            // Release resources
            session.close();
            env.close();
        }
    }

    // Helper that computes the cosine similarity of two vectors
    private static double cosineSimilarity(float[] vectorA, float[] vectorB) {
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < vectorA.length; i++) {
            dotProduct += vectorA[i] * vectorB[i];
            normA += Math.pow(vectorA[i], 2);
            normB += Math.pow(vectorB[i], 2);
        }
        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}
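The code casts `results.get(0).getValue()` to `float[][]`. If output 0 of this export is `last_hidden_state` with shape `[batch, seq_len, hidden]`, as is common for BERT feature-extraction exports, `getValue()` returns a `float[][][]` and that cast would throw a `ClassCastException` once the missing-input problem is solved. One common way to reduce it to one vector per sentence is masked mean pooling. The sketch below assumes that output layout and the same `attention_mask` fed to the model; `MeanPooling` is a hypothetical helper, not part of ONNX Runtime.

```java
/**
 * Hypothetical helper: masked mean pooling over a BERT-style output,
 * assuming hidden has shape [batch, seqLen, hidden] and attentionMask has shape [batch, seqLen].
 */
final class MeanPooling {
    private MeanPooling() {}

    static float[][] meanPool(float[][][] hidden, long[][] attentionMask) {
        float[][] pooled = new float[hidden.length][];
        for (int b = 0; b < hidden.length; b++) {
            int dim = hidden[b][0].length;
            double[] sum = new double[dim];
            long count = 0;
            for (int t = 0; t < hidden[b].length; t++) {
                if (attentionMask[b][t] == 0L) {
                    continue; // skip padding positions
                }
                count++;
                for (int d = 0; d < dim; d++) {
                    sum[d] += hidden[b][t][d];
                }
            }
            pooled[b] = new float[dim];
            for (int d = 0; d < dim; d++) {
                // Guard against an all-padding row to avoid dividing by zero.
                pooled[b][d] = count == 0 ? 0f : (float) (sum[d] / count);
            }
        }
        return pooled;
    }
}
```

The pooled rows are plain `float[]` vectors, so they can be passed to the existing `cosineSimilarity` helper unchanged.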
