Riva ASR Go Client and Mock Server

这个项目提供了 NVIDIA Riva 自动语音识别 (ASR) 服务的 Go 客户端库和模拟服务器实现。

项目结构

riva-proxy-go/
├── proto/                  # Protocol Buffers 定义和生成的 Go 代码
│   ├── riva_asr.proto     # Riva ASR 服务的 protobuf 定义
│   ├── riva_asr.pb.go     # 生成的 protobuf Go 代码
│   └── riva_asr_grpc.pb.go # 生成的 gRPC Go 代码
├── client/                 # Go 客户端库
│   └── client.go          # Riva ASR 客户端实现
├── server/                 # Mock 服务器实现
│   └── mock_server.go     # 模拟 Riva ASR 服务器
├── examples/               # 示例代码
│   ├── client_demo/       # 客户端使用示例
│   │   └── main.go
│   └── server_demo/       # 服务器启动示例
│       └── main.go
├── go.mod                 # Go 模块定义
└── README.md              # 项目文档

功能特性

客户端功能

  • 非流式识别: 一次性发送完整音频数据进行识别
  • 流式识别: 实时音频流识别,支持中间结果
  • 配置灵活: 支持多种音频编码格式和识别参数
  • 连接管理: 自动处理 gRPC 连接和错误重试
  • TLS 支持: 可选的安全连接

服务器功能

  • 完整的 Mock 实现: 实现了 Riva ASR 的所有 gRPC 接口
  • 模拟真实响应: 生成带有置信度、时间戳的识别结果
  • 流式处理: 支持实时音频流处理和中间结果
  • 可配置响应: 可以自定义模拟的识别结果
  • 日志记录: 详细的请求和响应日志

快速开始

安装依赖

go mod tidy

启动 Mock 服务器

cd examples/server_demo
go run main.go

服务器将在 localhost:50051 启动。

运行客户端示例

在另一个终端中:

cd examples/client_demo
go run main.go

使用方法

客户端使用

package main

import (
    "context"
    "log"
    "time"
    
    "riva-proxy-go/client"
    pb "riva-proxy-go/proto"
)

func main() {
    // 创建客户端配置
    config := client.ClientConfig{
        ServerAddress: "localhost:50051",
        UseTLS:        false,
        Timeout:       30 * time.Second,
    }

    // 创建客户端
    rivaClient, err := client.NewRivaClient(config)
    if err != nil {
        log.Fatalf("Failed to create client: %v", err)
    }
    defer rivaClient.Close()

    ctx := context.Background()

    // 非流式识别
    recognitionConfig := client.NewBasicRecognitionConfig(
        pb.RecognitionConfig_LINEAR_PCM,
        16000, // 16kHz 采样率
        "en-US",
    )

    audioData := []byte{} // 你的音频数据
    resp, err := rivaClient.Recognize(ctx, recognitionConfig, audioData)
    if err != nil {
        log.Printf("Recognition failed: %v", err)
        return
    }

    // 处理结果
    for _, result := range resp.Results {
        for _, alternative := range result.Alternatives {
            log.Printf("识别结果: %s (置信度: %.2f)", 
                alternative.Transcript, alternative.Confidence)
        }
    }
}

流式识别

// 创建流式配置
streamingConfig := client.NewStreamingRecognitionConfig(recognitionConfig, true)

// 开始流式会话
session, err := rivaClient.StreamingRecognize(ctx, streamingConfig)
if err != nil {
    log.Fatalf("Failed to create streaming session: %v", err)
}
defer session.Close()

// 发送音频数据
go func() {
    for {
        audioChunk := getNextAudioChunk() // 获取音频块
        if err := session.SendAudio(audioChunk); err != nil {
            log.Printf("Failed to send audio: %v", err)
            return
        }
    }
}()

// 接收识别结果
session.ReceiveAll(func(resp *pb.StreamingRecognizeResponse) error {
    for _, result := range resp.Results {
        for _, alternative := range result.Alternatives {
            log.Printf("流式结果: %s (最终: %t)", 
                alternative.Transcript, result.IsFinal)
        }
    }
    return nil
})

启动 Mock 服务器

package main

import (
    "log"
    "riva-proxy-go/server"
)

func main() {
    config := server.ServerConfig{
        Port:   50051,
        UseTLS: false,
    }

    log.Println("启动 Mock 服务器...")
    if err := server.StartMockServer(config); err != nil {
        log.Fatalf("服务器启动失败: %v", err)
    }
}

API 参考

客户端 API

NewRivaClient(config ClientConfig) (*RivaClient, error)

创建新的 Riva 客户端实例。

Recognize(ctx context.Context, config *pb.RecognitionConfig, audioData []byte) (*pb.RecognizeResponse, error)

执行非流式语音识别。

StreamingRecognize(ctx context.Context, config *pb.StreamingRecognitionConfig) (*StreamingSession, error)

开始流式语音识别会话。

服务器 API

StartMockServer(config ServerConfig) error

启动 Mock gRPC 服务器。

NewMockRivaServer() *MockRivaServer

创建新的 Mock 服务器实例。

配置选项

客户端配置 (ClientConfig)

  • ServerAddress: 服务器地址 (例如: "localhost:50051")
  • UseTLS: 是否使用 TLS 连接
  • Timeout: 连接超时时间

服务器配置 (ServerConfig)

  • Port: 监听端口
  • UseTLS: 是否启用 TLS
  • CertFile: TLS 证书文件路径
  • KeyFile: TLS 私钥文件路径

识别配置 (RecognitionConfig)

  • Encoding: 音频编码格式 (LINEAR_PCM, FLAC, MULAW, ALAW, LINEAR16)
  • SampleRateHertz: 采样率 (例如: 16000)
  • LanguageCode: 语言代码 (例如: "en-US", "zh-CN")
  • MaxAlternatives: 最大候选结果数
  • EnableAutomaticPunctuation: 启用自动标点
  • EnableWordTimeOffsets: 启用单词时间戳
  • DiarizationConfig: 说话人分离配置

测试

运行所有测试:

go test ./...

运行特定包的测试:

go test ./client
go test ./server

开发

重新生成 protobuf 代码

如果修改了 .proto 文件,需要重新生成 Go 代码:

protoc --go_out=. --go_opt=paths=source_relative \
       --go-grpc_out=. --go-grpc_opt=paths=source_relative \
       proto/riva_asr.proto

添加新功能

  1. 修改相应的 .go 文件
  2. 添加测试用例
  3. 更新文档
  4. 运行测试确保功能正常

许可证

本项目采用 MIT 许可证。详见 LICENSE 文件。

贡献

欢迎提交 Issue 和 Pull Request!

相关链接