Binding kaldi-native-fbank for Kotlin
This note shows how to call https://github.com/csukuangfj/kaldi-native-fbank (and a small part of sherpa-ncnn) from Kotlin through JNI.
./code/jni/feat-extractor/Makefile
# Builds the Kotlin demo (main.jar) and the JNI library it loads
# (libsherpa-ncnn.so).
#
# SHERPA_NCNN_INSTALL_DIR must point to an installed sherpa-ncnn tree. Override
# it from the environment or the command line, e.g.
#   make SHERPA_NCNN_INSTALL_DIR=/path/to/sherpa-ncnn/build/install
SHERPA_NCNN_INSTALL_DIR ?= /ceph-fj/fangjun/open-source/sherpa-ncnn/build/install

KOTLIN_SOURCES := Main.kt OnlineFeature.kt WaveReader.kt Model.kt
JNI_SOURCES := online-feature.cc sherpa-ncnn.cc
JNI_HEADERS := online-feature.h sherpa-ncnn.h

CXXFLAGS := -I $(JAVA_HOME)/include
CXXFLAGS += -I $(JAVA_HOME)/include/linux
CXXFLAGS += -I $(SHERPA_NCNN_INSTALL_DIR)/include
CXXFLAGS += -I $(SHERPA_NCNN_INSTALL_DIR)/include/ncnn
CXXFLAGS += -Wall

LDFLAGS := -L $(SHERPA_NCNN_INSTALL_DIR)/lib -lkaldi-native-fbank-core -lsherpa-ncnn-core -lncnn
# Embed the library directory so the dependencies resolve at run time.
LDFLAGS += -Wl,-rpath,$(SHERPA_NCNN_INSTALL_DIR)/lib

# These targets do not produce files of the same name.
.PHONY: all run clean

all: main.jar libsherpa-ncnn.so

main.jar: $(KOTLIN_SOURCES)
	kotlinc-jvm -include-runtime -d $@ $(KOTLIN_SOURCES)

libsherpa-ncnn.so: $(JNI_SOURCES) $(JNI_HEADERS)
	$(CXX) -o $@ -shared -fPIC $(CXXFLAGS) $(JNI_SOURCES) $(LDFLAGS)

# java.library.path is set to the build directory so that
# System.loadLibrary("sherpa-ncnn") can find libsherpa-ncnn.so there.
run: all
	java -Djava.library.path=$(CURDIR) -jar main.jar

clean:
	$(RM) main.jar libsherpa-ncnn.so
./code/jni/feat-extractor/online-feature.h
#ifndef ONLINE_FEATURE_H_
#define ONLINE_FEATURE_H_

#include "jni.h"

// JNI entry points for the native methods of the Kotlin class OnlineFbank.
//
// In every function except Java_OnlineFbank_new, `ptr` is the handle that
// Java_OnlineFbank_new returned.

#ifdef __cplusplus
extern "C" {
#endif

// Builds a native extractor from a Kotlin FbankOptions object and returns its
// address as a handle.
JNIEXPORT jlong JNICALL
Java_OnlineFbank_new(JNIEnv *env, jobject obj, jobject opts);

// Destroys the extractor behind `ptr`.
JNIEXPORT void JNICALL
Java_OnlineFbank_delete(JNIEnv *env, jobject obj, jlong ptr);

// Returns Dim() of the extractor.
JNIEXPORT jint JNICALL
Java_OnlineFbank_dim(JNIEnv *env, jobject obj, jlong ptr);

// Returns FrameShiftInSeconds() of the extractor.
JNIEXPORT jfloat JNICALL
Java_OnlineFbank_frameShiftInSeconds(JNIEnv *env, jobject obj, jlong ptr);

// Returns NumFramesReady() of the extractor.
JNIEXPORT jint JNICALL
Java_OnlineFbank_numFramesReady(JNIEnv *env, jobject obj, jlong ptr);

// Returns IsLastFrame(i) of the extractor.
JNIEXPORT jboolean JNICALL
Java_OnlineFbank_isLastFrame(JNIEnv *env, jobject obj, jlong ptr, jint i);

// Calls InputFinished() on the extractor.
JNIEXPORT void JNICALL
Java_OnlineFbank_inputFinished(JNIEnv *env, jobject obj, jlong ptr);

// Passes `samples` and `sample_rate` to AcceptWaveform() of the extractor.
JNIEXPORT void JNICALL
Java_OnlineFbank_acceptWaveform(JNIEnv *env, jobject obj, jlong ptr,
                                jfloatArray samples, jfloat sample_rate);

// Returns frame `i` as a new float array with dim() entries.
JNIEXPORT jfloatArray JNICALL
Java_OnlineFbank_getFrame(JNIEnv *env, jobject obj, jlong ptr, jint i);

// Returns frames [start, start + n) concatenated into one new float array, or
// NULL when start + n exceeds the number of frames that are ready.
JNIEXPORT jfloatArray JNICALL
Java_OnlineFbank_getFrames(JNIEnv *env, jobject /*obj*/, jlong ptr,
                           jint start, jint n);

#ifdef __cplusplus
}
#endif

#endif  // ONLINE_FEATURE_H_
./code/jni/feat-extractor/online-feature.cc
#include "online-feature.h"
#include "kaldi-native-fbank/csrc/online-feature.h"
JNIEXPORT jlong JNICALL Java_OnlineFbank_new(JNIEnv *env, jobject /*obj*/,
jobject opts) {
jclass cls = env->GetObjectClass(opts);
jfieldID fid;
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
knf::FbankOptions fbank_opts;
fid = env->GetFieldID(cls, "use_energy", "Z");
fbank_opts.use_energy = env->GetBooleanField(opts, fid);
fid = env->GetFieldID(cls, "energy_floor", "F");
fbank_opts.energy_floor = env->GetFloatField(opts, fid);
fid = env->GetFieldID(cls, "raw_energy", "Z");
fbank_opts.raw_energy = env->GetBooleanField(opts, fid);
fid = env->GetFieldID(cls, "htk_compat", "Z");
fbank_opts.htk_compat = env->GetBooleanField(opts, fid);
fid = env->GetFieldID(cls, "use_log_fbank", "Z");
fbank_opts.use_log_fbank = env->GetBooleanField(opts, fid);
fid = env->GetFieldID(cls, "use_power", "Z");
fbank_opts.use_power = env->GetBooleanField(opts, fid);
fid = env->GetFieldID(cls, "frame_opts", "LFrameExtractionOptions;");
jobject frame_opts = env->GetObjectField(opts, fid);
jclass frame_opts_cls = env->GetObjectClass(frame_opts);
fid = env->GetFieldID(frame_opts_cls, "samp_freq", "F");
fbank_opts.frame_opts.samp_freq = env->GetFloatField(frame_opts, fid);
fid = env->GetFieldID(frame_opts_cls, "frame_shift_ms", "F");
fbank_opts.frame_opts.frame_shift_ms = env->GetFloatField(frame_opts, fid);
fid = env->GetFieldID(frame_opts_cls, "frame_length_ms", "F");
fbank_opts.frame_opts.frame_length_ms = env->GetFloatField(frame_opts, fid);
fid = env->GetFieldID(frame_opts_cls, "dither", "F");
fbank_opts.frame_opts.dither = env->GetFloatField(frame_opts, fid);
fid = env->GetFieldID(frame_opts_cls, "preemph_coeff", "F");
fbank_opts.frame_opts.preemph_coeff = env->GetFloatField(frame_opts, fid);
fid = env->GetFieldID(frame_opts_cls, "remove_dc_offset", "Z");
fbank_opts.frame_opts.remove_dc_offset =
env->GetBooleanField(frame_opts, fid);
fid = env->GetFieldID(frame_opts_cls, "window_type", "Ljava/lang/String;");
jstring window_type = (jstring)env->GetObjectField(frame_opts, fid);
const char *p_window_type = env->GetStringUTFChars(window_type, nullptr);
fbank_opts.frame_opts.window_type = p_window_type;
env->ReleaseStringUTFChars(window_type, p_window_type);
fid = env->GetFieldID(frame_opts_cls, "round_to_power_of_two", "Z");
fbank_opts.frame_opts.round_to_power_of_two =
env->GetBooleanField(frame_opts, fid);
fid = env->GetFieldID(frame_opts_cls, "blackman_coeff", "F");
fbank_opts.frame_opts.blackman_coeff = env->GetFloatField(frame_opts, fid);
fid = env->GetFieldID(frame_opts_cls, "snip_edges", "Z");
fbank_opts.frame_opts.snip_edges = env->GetBooleanField(frame_opts, fid);
fid = env->GetFieldID(frame_opts_cls, "max_feature_vectors", "I");
fbank_opts.frame_opts.max_feature_vectors = env->GetIntField(frame_opts, fid);
fid = env->GetFieldID(cls, "mel_opts", "LMelBanksOptions;");
jobject mel_opts = env->GetObjectField(opts, fid);
jclass mel_opts_cls = env->GetObjectClass(mel_opts);
fid = env->GetFieldID(mel_opts_cls, "num_bins", "I");
fbank_opts.mel_opts.num_bins = env->GetIntField(mel_opts, fid);
fid = env->GetFieldID(mel_opts_cls, "low_freq", "F");
fbank_opts.mel_opts.low_freq = env->GetFloatField(mel_opts, fid);
fid = env->GetFieldID(mel_opts_cls, "high_freq", "F");
fbank_opts.mel_opts.high_freq = env->GetFloatField(mel_opts, fid);
fid = env->GetFieldID(mel_opts_cls, "vtln_low", "F");
fbank_opts.mel_opts.vtln_low = env->GetFloatField(mel_opts, fid);
fid = env->GetFieldID(mel_opts_cls, "vtln_high", "F");
fbank_opts.mel_opts.vtln_high = env->GetFloatField(mel_opts, fid);
fid = env->GetFieldID(mel_opts_cls, "debug_mel", "Z");
fbank_opts.mel_opts.debug_mel = env->GetBooleanField(mel_opts, fid);
fid = env->GetFieldID(mel_opts_cls, "htk_mode", "Z");
fbank_opts.mel_opts.htk_mode = env->GetBooleanField(mel_opts, fid);
auto online_fbank = new knf::OnlineFbank(fbank_opts);
return (jlong)online_fbank;
}
// Destroys the knf::OnlineFbank whose address is `ptr`.
JNIEXPORT void JNICALL Java_OnlineFbank_delete(JNIEnv * /*env*/,
                                               jobject /*obj*/, jlong ptr) {
  auto *fbank = reinterpret_cast<knf::OnlineFbank *>(ptr);
  delete fbank;
}
// Returns Dim() of the knf::OnlineFbank whose address is `ptr`.
JNIEXPORT jint JNICALL Java_OnlineFbank_dim(JNIEnv * /*env*/, jobject /*obj*/,
                                            jlong ptr) {
  const auto *fbank = reinterpret_cast<const knf::OnlineFbank *>(ptr);
  return fbank->Dim();
}
// Returns FrameShiftInSeconds() of the knf::OnlineFbank whose address is `ptr`.
JNIEXPORT jfloat JNICALL Java_OnlineFbank_frameShiftInSeconds(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  const auto *fbank = reinterpret_cast<const knf::OnlineFbank *>(ptr);
  return fbank->FrameShiftInSeconds();
}
// Returns NumFramesReady() of the knf::OnlineFbank whose address is `ptr`.
JNIEXPORT jint JNICALL Java_OnlineFbank_numFramesReady(JNIEnv * /*env*/,
                                                       jobject /*obj*/,
                                                       jlong ptr) {
  const auto *fbank = reinterpret_cast<const knf::OnlineFbank *>(ptr);
  return fbank->NumFramesReady();
}
// Returns IsLastFrame(i) of the knf::OnlineFbank whose address is `ptr`.
JNIEXPORT jboolean JNICALL Java_OnlineFbank_isLastFrame(JNIEnv * /*env*/,
                                                        jobject /*obj*/,
                                                        jlong ptr, jint i) {
  const auto *fbank = reinterpret_cast<const knf::OnlineFbank *>(ptr);
  return fbank->IsLastFrame(i);
}
// Calls InputFinished() on the knf::OnlineFbank whose address is `ptr`.
JNIEXPORT void JNICALL Java_OnlineFbank_inputFinished(JNIEnv * /*env*/,
                                                      jobject /*obj*/,
                                                      jlong ptr) {
  auto *fbank = reinterpret_cast<knf::OnlineFbank *>(ptr);
  fbank->InputFinished();
}
// Passes the contents of `samples` and `sample_rate` to AcceptWaveform() of
// the knf::OnlineFbank whose address is `ptr`.
//
// Returns without feeding anything if the array elements cannot be obtained;
// in that case GetFloatArrayElements has left a Java exception pending.
JNIEXPORT void JNICALL Java_OnlineFbank_acceptWaveform(JNIEnv *env,
                                                       jobject /*obj*/,
                                                       jlong ptr,
                                                       jfloatArray samples,
                                                       jfloat sample_rate) {
  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  if (p == nullptr) {
    return;
  }
  jsize n = env->GetArrayLength(samples);
  reinterpret_cast<knf::OnlineFbank *>(ptr)->AcceptWaveform(sample_rate, p, n);
  // JNI_ABORT: the buffer was only read, so nothing needs to be copied back.
  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
// Returns frame `i` of the knf::OnlineFbank whose address is `ptr` as a new
// float array with Dim() entries.
//
// Returns nullptr when `i` is outside [0, NumFramesReady()) or when the result
// array cannot be allocated (a Java exception is then pending).
JNIEXPORT jfloatArray JNICALL Java_OnlineFbank_getFrame(JNIEnv *env,
                                                        jobject /*obj*/,
                                                        jlong ptr, jint i) {
  auto online_fbank = reinterpret_cast<const knf::OnlineFbank *>(ptr);
  // Same convention as Java_OnlineFbank_getFrames: out of range yields null.
  if (i < 0 || i >= online_fbank->NumFramesReady()) {
    return nullptr;
  }
  auto frame = online_fbank->GetFrame(i);
  auto dim = online_fbank->Dim();
  jfloatArray ans = env->NewFloatArray(dim);
  if (ans == nullptr) {
    return nullptr;
  }
  env->SetFloatArrayRegion(ans, 0, dim, frame);
  return ans;
}
// Returns frames [start, start + n) of the knf::OnlineFbank whose address is
// `ptr`, concatenated frame after frame into one new float array of
// n * Dim() entries.
//
// Returns nullptr when `start` or `n` is negative, when the range extends past
// NumFramesReady(), or when the result array cannot be allocated (a Java
// exception is then pending).
JNIEXPORT jfloatArray JNICALL Java_OnlineFbank_getFrames(JNIEnv *env,
                                                         jobject /*obj*/,
                                                         jlong ptr, jint start,
                                                         jint n) {
  auto online_fbank = reinterpret_cast<const knf::OnlineFbank *>(ptr);
  auto dim = online_fbank->Dim();
  // Written as a subtraction so that start + n cannot overflow.
  if (start < 0 || n < 0 || n > online_fbank->NumFramesReady() - start) {
    return nullptr;
  }
  jfloatArray ans = env->NewFloatArray(n * dim);
  if (ans == nullptr) {
    return nullptr;
  }
  for (int32_t i = 0; i != n; ++i) {
    auto frame = online_fbank->GetFrame(start + i);
    env->SetFloatArrayRegion(ans, i * dim, dim, frame);
  }
  return ans;
}
./code/jni/feat-extractor/sherpa-ncnn.h
#ifndef SHERPA_NCNN_H_
#define SHERPA_NCNN_H_

#include "jni.h"

// JNI entry points for the native methods of the Kotlin classes WaveReader
// (companion object) and Model.

#ifdef __cplusplus
extern "C" {
#endif

// Reads the wave file `filename` and returns its samples as a new float array,
// or NULL when the reader reports a failure.
JNIEXPORT jfloatArray JNICALL
Java_WaveReader_00024Companion_readWave(JNIEnv *env, jclass cls,
                                        jstring filename,
                                        jfloat expected_sample_rate);

// Creates a native model from a Kotlin ModelConfig object and returns its
// address as a handle.
JNIEXPORT jlong JNICALL
Java_Model_new(JNIEnv *env, jobject /*obj*/, jobject config);

// Destroys the model whose address is `ptr`.
JNIEXPORT void JNICALL
Java_Model_delete(JNIEnv *env, jobject /*obj*/, jlong ptr);

// Returns Segment() of the model whose address is `ptr`.
JNIEXPORT jint JNICALL
Java_Model_segment(JNIEnv *env, jobject /*obj*/, jlong ptr);

// Returns Offset() of the model whose address is `ptr`.
JNIEXPORT jint JNICALL
Java_Model_offset(JNIEnv *env, jobject /*obj*/, jlong ptr);

#ifdef __cplusplus
}
#endif

#endif  // SHERPA_NCNN_H_
./code/jni/feat-extractor/sherpa-ncnn.cc
#include "sherpa-ncnn.h"

#include <iostream>
#include <string>
#include <vector>

#include "sherpa-ncnn/csrc/model.h"
#include "sherpa-ncnn/csrc/wave-reader.h"
// Reads the wave file `filename` with sherpa_ncnn::ReadWave and returns the
// samples as a new float array.
//
// Returns nullptr when the file name cannot be converted, when ReadWave
// reports a failure, or when the result array cannot be allocated. In the
// first and last case JNI has left a Java exception pending.
JNIEXPORT jfloatArray JNICALL Java_WaveReader_00024Companion_readWave(
    JNIEnv *env, jclass /*cls*/, jstring filename,
    jfloat expected_sample_rate) {
  const char *p_filename = env->GetStringUTFChars(filename, nullptr);
  if (p_filename == nullptr) {
    return nullptr;
  }

  bool is_ok = false;
  std::vector<float> samples =
      sherpa_ncnn::ReadWave(p_filename, expected_sample_rate, &is_ok);
  env->ReleaseStringUTFChars(filename, p_filename);
  if (!is_ok) {
    return nullptr;
  }

  // The narrowing from size_t to jsize is made explicit.
  const jsize num_samples = static_cast<jsize>(samples.size());
  jfloatArray ans = env->NewFloatArray(num_samples);
  if (ans == nullptr) {
    return nullptr;
  }
  env->SetFloatArrayRegion(ans, 0, num_samples, samples.data());
  return ans;
}
JNIEXPORT jlong JNICALL Java_Model_new(JNIEnv *env, jobject /*obj*/,
jobject config) {
sherpa_ncnn::ModelConfig model_config;
jclass cls = env->GetObjectClass(config);
jfieldID fid = env->GetFieldID(cls, "encoderParam", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
model_config.encoder_param = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "encoderBin", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
model_config.encoder_bin = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "decoderParam", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
model_config.decoder_param = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "decoderBin", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
model_config.decoder_bin = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "joinerParam", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
model_config.joiner_param = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "joinerBin", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
model_config.joiner_bin = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "numThreads", "I");
model_config.num_threads = env->GetIntField(config, fid);
std::cout << model_config.ToString() << "\n";
auto model = sherpa_ncnn::Model::Create(model_config);
return (jlong)(model.release());
}
// Destroys the sherpa_ncnn::Model whose address is `ptr`.
JNIEXPORT void JNICALL Java_Model_delete(JNIEnv * /*env*/, jobject /*obj*/,
                                         jlong ptr) {
  auto *model = reinterpret_cast<sherpa_ncnn::Model *>(ptr);
  delete model;
}
// Returns Segment() of the sherpa_ncnn::Model whose address is `ptr`.
JNIEXPORT jint JNICALL Java_Model_segment(JNIEnv * /*env*/, jobject /*obj*/,
                                          jlong ptr) {
  const auto *model = reinterpret_cast<const sherpa_ncnn::Model *>(ptr);
  return model->Segment();
}
// Returns Offset() of the sherpa_ncnn::Model whose address is `ptr`.
JNIEXPORT jint JNICALL Java_Model_offset(JNIEnv * /*env*/, jobject /*obj*/,
                                         jlong ptr) {
  const auto *model = reinterpret_cast<const sherpa_ncnn::Model *>(ptr);
  return model->Offset();
}
./code/jni/feat-extractor/Main.kt
/**
 * Demo entry point.
 *
 * Feeds a wave file to an [OnlineFbank] configured with 80 mel bins, then
 * creates a [Model] and prints its `segment` and `offset` values.
 */
fun main() {
    val fbankOpts = FbankOptions().apply { mel_opts.num_bins = 80 }
    val onlineFbank = OnlineFbank(fbankOpts)

    val wavePath = "./1089-134686-0001.wav"
    val samples = WaveReader.readWave(wavePath, 16000.0f)
    if (samples != null) {
        onlineFbank.acceptWaveform(samples)
    } else {
        // readWave returns null on failure; report it instead of continuing silently.
        System.err.println("Failed to read $wavePath; no samples were fed to the feature extractor")
    }
    onlineFbank.inputFinished()

    val modelDir = "./sherpa-ncnn-conv-emformer-transducer-2022-12-04"
    val modelConfig = ModelConfig(
        encoderParam = "$modelDir/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param",
        encoderBin = "$modelDir/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin",
        decoderParam = "$modelDir/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param",
        decoderBin = "$modelDir/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin",
        joinerParam = "$modelDir/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.param",
        joinerBin = "$modelDir/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin",
        numThreads = 4,
    )

    val model = Model(modelConfig)
    println("segment: ${model.segment}")
    println("offset: ${model.offset}")
}
./code/jni/feat-extractor/OnlineFeature.kt
/**
 * Frame extraction settings passed to the native feature extractor.
 *
 * The native function `Java_OnlineFbank_new` looks every property up by name
 * and JNI type and copies it into the like-named field of
 * `knf::FbankOptions::frame_opts`, so names and types must stay as they are.
 */
data class FrameExtractionOptions(
    var samp_freq: Float = 16000.0f,
    var frame_shift_ms: Float = 10.0f,
    var frame_length_ms: Float = 25.0f,
    var dither: Float = 0.0f,
    var preemph_coeff: Float = 0.97f,
    var remove_dc_offset: Boolean = true,
    var window_type: String = "povey",
    var round_to_power_of_two: Boolean = true,
    var blackman_coeff: Float = 0.42f,
    var snip_edges: Boolean = true,
    var max_feature_vectors: Int = -1,
)
/**
 * Mel filter bank settings passed to the native feature extractor.
 *
 * The native function `Java_OnlineFbank_new` looks every property up by name
 * and JNI type and copies it into the like-named field of
 * `knf::FbankOptions::mel_opts`, so names and types must stay as they are.
 */
data class MelBanksOptions(
    var num_bins: Int = 25,
    var low_freq: Float = 20.0f,
    var high_freq: Float = 0.0f,
    var vtln_low: Float = 100.0f,
    var vtln_high: Float = -500.0f,
    var debug_mel: Boolean = false,
    var htk_mode: Boolean = false,
)
/**
 * Top-level options used to construct an [OnlineFbank].
 *
 * The native function `Java_OnlineFbank_new` looks every property up by name
 * and JNI type and copies it into the like-named field of
 * `knf::FbankOptions`. The nested objects are located through the type
 * signatures `LFrameExtractionOptions;` and `LMelBanksOptions;`, so property
 * names, property types and the names of those two classes must stay as they
 * are.
 */
data class FbankOptions(
    var frame_opts: FrameExtractionOptions = FrameExtractionOptions(),
    var mel_opts: MelBanksOptions = MelBanksOptions(),
    var use_energy: Boolean = false,
    var energy_floor: Float = 0.0f,
    var raw_energy: Boolean = true,
    var htk_compat: Boolean = false,
    var use_log_fbank: Boolean = true,
    var use_power: Boolean = true,
)
/**
 * Kotlin wrapper around a native `knf::OnlineFbank` feature extractor.
 *
 * The native object is created from [opts] in the initializer. It is freed by
 * [release], or by the finalizer if [release] was never called. After
 * [release], every other member throws [IllegalStateException].
 */
class OnlineFbank(var opts: FbankOptions) {
    // Address of the native object; 0 once it has been released.
    private var ptr: Long

    init {
        ptr = new(opts)
        check(ptr != 0L) { "Failed to create the native OnlineFbank" }
    }

    /** Frees the native object. Calling it more than once is harmless. */
    fun release() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0L
        }
    }

    protected fun finalize() {
        release()
    }

    val dim: Int
        get() = dim(handle())

    val frameShiftInSeconds: Float
        get() = frameShiftInSeconds(handle())

    val numFramesReady: Int
        get() = numFramesReady(handle())

    fun isLastFrame(i: Int): Boolean = isLastFrame(handle(), i)

    fun inputFinished() = inputFinished(handle())

    /** Feeds [samples] to the extractor, using `opts.frame_opts.samp_freq` as the sample rate. */
    fun acceptWaveform(samples: FloatArray) =
        acceptWaveform(handle(), samples, opts.frame_opts.samp_freq)

    /**
     * Returns frame [i] as an array of [dim] floats.
     *
     * @throws IllegalArgumentException if [i] is not in `0 until numFramesReady`.
     */
    fun getFrame(i: Int): FloatArray {
        require(i in 0 until numFramesReady) { "frame index $i is out of range" }
        return getFrame(handle(), i) ?: error("native getFrame($i) returned null")
    }

    /**
     * Returns frames `start until start + n`, concatenated into one array of
     * `n * dim` floats.
     *
     * @throws IllegalArgumentException if the range is negative or extends
     *   past [numFramesReady].
     */
    fun getFrames(start: Int, n: Int): FloatArray {
        // Written as a subtraction so that start + n cannot overflow.
        require(start >= 0 && n >= 0 && n <= numFramesReady - start) {
            "frame range [$start, $start + $n) is out of range"
        }
        return getFrames(handle(), start, n) ?: error("native getFrames($start, $n) returned null")
    }

    // Returns the native handle, refusing to hand out a released one.
    private fun handle(): Long {
        check(ptr != 0L) { "OnlineFbank has been released" }
        return ptr
    }

    private external fun new(opts: FbankOptions): Long
    private external fun delete(ptr: Long)
    private external fun dim(ptr: Long): Int
    private external fun frameShiftInSeconds(ptr: Long): Float
    private external fun numFramesReady(ptr: Long): Int
    private external fun isLastFrame(ptr: Long, i: Int): Boolean
    private external fun inputFinished(ptr: Long)
    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sample_rate: Float)

    // The native side may return null (out-of-range request or failed
    // allocation), hence the nullable return types.
    private external fun getFrame(ptr: Long, i: Int): FloatArray?
    private external fun getFrames(ptr: Long, start: Int, n: Int): FloatArray?

    companion object {
        init {
            System.loadLibrary("sherpa-ncnn")
        }
    }
}
./code/jni/feat-extractor/WaveReader.kt
/**
 * Holder for the native wave-file reader.
 *
 * [readWave] is implemented by the native symbol
 * `Java_WaveReader_00024Companion_readWave`, so it has to remain a member of
 * this class's companion object.
 */
class WaveReader {
    companion object {
        init {
            System.loadLibrary("sherpa-ncnn")
        }

        /**
         * Reads a mono wave file. No resampling is made.
         *
         * @return the samples, or `null` when the native reader reports a failure.
         */
        external fun readWave(
            filename: String,
            expected_sample_rate: Float = 16000.0f,
        ): FloatArray?
    }
}
./code/jni/feat-extractor/Model.kt
/**
 * Settings used to construct a [Model].
 *
 * The native function `Java_Model_new` looks every property up by name and
 * copies it into a `sherpa_ncnn::ModelConfig`, so property names and types
 * must stay as they are.
 */
data class ModelConfig(
    var encoderParam: String,
    var encoderBin: String,
    var decoderParam: String,
    var decoderBin: String,
    var joinerParam: String,
    var joinerBin: String,
    var numThreads: Int = 4,
)
/**
 * Kotlin wrapper around a native `sherpa_ncnn::Model`.
 *
 * The native model is created from [config] in the initializer and freed by
 * the finalizer.
 */
class Model(var config: ModelConfig) {
    // Address of the native model.
    private var ptr: Long

    init {
        ptr = new(config)
        // A zero handle means no model was created; dereferencing it natively would crash.
        check(ptr != 0L) { "Failed to create the native model from $config" }
    }

    protected fun finalize() {
        delete(ptr)
    }

    val segment: Int
        get() = segment(ptr)

    val offset: Int
        get() = offset(ptr)

    private external fun new(config: ModelConfig): Long
    private external fun delete(ptr: Long)
    private external fun segment(ptr: Long): Int
    private external fun offset(ptr: Long): Int

    companion object {
        // Load the JNI library here as well, so Model works even when it is
        // the first class in the program to call into native code.
        init {
            System.loadLibrary("sherpa-ncnn")
        }
    }
}