Binding kaldi-native-fbank for Kotlin

This note shows how to bind https://github.com/csukuangfj/kaldi-native-fbank for Kotlin using JNI.

./code/jni/feat-extractor/Makefile
# Builds the Kotlin demo (main.jar) and the JNI shared library it loads
# (libsherpa-ncnn.so).

# Root of the sherpa-ncnn installation that provides the headers and the
# libraries linked below. `?=` keeps the default but lets the environment
# (or the make command line) point at another installation.
SHERPA_NCNN_INSTALL_DIR ?= /ceph-fj/fangjun/open-source/sherpa-ncnn/build/install

KOTLIN_SRCS := Main.kt OnlineFeature.kt WaveReader.kt Model.kt
JNI_SRCS := online-feature.cc sherpa-ncnn.cc
JNI_HDRS := online-feature.h sherpa-ncnn.h

# jni.h lives in $(JAVA_HOME)/include; the platform specific part of it is
# taken from the linux sub-directory.
CXXFLAGS := -I $(JAVA_HOME)/include
CXXFLAGS += -I $(JAVA_HOME)/include/linux
CXXFLAGS += -I $(SHERPA_NCNN_INSTALL_DIR)/include
CXXFLAGS += -I $(SHERPA_NCNN_INSTALL_DIR)/include/ncnn
CXXFLAGS += -Wall

# The rpath lets the dynamic linker find the sherpa-ncnn libraries at run
# time without setting LD_LIBRARY_PATH.
LDFLAGS := -L $(SHERPA_NCNN_INSTALL_DIR)/lib -lkaldi-native-fbank-core -lsherpa-ncnn-core -lncnn
LDFLAGS += -Wl,-rpath,$(SHERPA_NCNN_INSTALL_DIR)/lib

# These targets do not produce a file with the same name.
.PHONY: all run clean

all: main.jar libsherpa-ncnn.so

main.jar: $(KOTLIN_SRCS)
	kotlinc-jvm -include-runtime -d main.jar $(KOTLIN_SRCS)

libsherpa-ncnn.so: $(JNI_SRCS) $(JNI_HDRS)
	$(CXX) -o $@ -shared -fPIC $(CXXFLAGS) $(JNI_SRCS) $(LDFLAGS)

# System.loadLibrary("sherpa-ncnn") searches java.library.path, which does
# not contain the current directory by default, so add it explicitly.
run: all
	java -Djava.library.path=. -jar main.jar

clean:
	$(RM) main.jar libsherpa-ncnn.so
./code/jni/feat-extractor/online-feature.h
#ifndef ONLINE_FEATURE_H_
#define ONLINE_FEATURE_H_
#include "jni.h"

// JNI entry points backing the `external` functions of the Kotlin class
// OnlineFbank. In every function that takes `ptr`, it is the value returned
// by Java_OnlineFbank_new, i.e. the address of a knf::OnlineFbank.

#ifdef __cplusplus
extern "C" {
#endif

// Creates a knf::OnlineFbank configured from the fields of `opts` (a Kotlin
// FbankOptions) and returns its address.
JNIEXPORT jlong JNICALL Java_OnlineFbank_new(JNIEnv *env, jobject obj,
                                             jobject opts);

// Deletes the knf::OnlineFbank that `ptr` points to.
JNIEXPORT void JNICALL Java_OnlineFbank_delete(JNIEnv *env, jobject obj,
                                               jlong ptr);

// Returns Dim() of the native object.
JNIEXPORT jint JNICALL Java_OnlineFbank_dim(JNIEnv *env, jobject obj,
                                            jlong ptr);

// Returns FrameShiftInSeconds() of the native object.
JNIEXPORT jfloat JNICALL Java_OnlineFbank_frameShiftInSeconds(JNIEnv *env,
                                                              jobject obj,
                                                              jlong ptr);

// Returns NumFramesReady() of the native object.
JNIEXPORT jint JNICALL Java_OnlineFbank_numFramesReady(JNIEnv *env, jobject obj,
                                                       jlong ptr);

// Returns IsLastFrame(i) of the native object.
JNIEXPORT jboolean JNICALL Java_OnlineFbank_isLastFrame(JNIEnv *env,
                                                        jobject obj, jlong ptr,
                                                        jint i);

// Calls InputFinished() on the native object.
JNIEXPORT void JNICALL Java_OnlineFbank_inputFinished(JNIEnv *env, jobject obj,
                                                      jlong ptr);

// Passes the content of `samples` together with `sample_rate` to
// AcceptWaveform() of the native object.
JNIEXPORT void JNICALL Java_OnlineFbank_acceptWaveform(JNIEnv *env, jobject obj,
                                                       jlong ptr,
                                                       jfloatArray samples,
                                                       jfloat sample_rate);

// Returns a new float array holding the Dim() values of frame `i`.
JNIEXPORT jfloatArray JNICALL Java_OnlineFbank_getFrame(JNIEnv *env,
                                                        jobject obj, jlong ptr,
                                                        jint i);

// Returns a new float array of n * Dim() values holding the frames
// start, start + 1, ..., start + n - 1 one after another, or null when
// start + n exceeds NumFramesReady().
JNIEXPORT jfloatArray JNICALL Java_OnlineFbank_getFrames(JNIEnv *env,
                                                         jobject /*obj*/,
                                                         jlong ptr, jint start,
                                                         jint n);

#ifdef __cplusplus
}
#endif

#endif // ONLINE_FEATURE_H_
./code/jni/feat-extractor/online-feature.cc
#include "online-feature.h"
#include "kaldi-native-fbank/csrc/online-feature.h"

JNIEXPORT jlong JNICALL Java_OnlineFbank_new(JNIEnv *env, jobject /*obj*/,
                                             jobject opts) {
  jclass cls = env->GetObjectClass(opts);
  jfieldID fid;

  // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
  // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html

  knf::FbankOptions fbank_opts;

  fid = env->GetFieldID(cls, "use_energy", "Z");
  fbank_opts.use_energy = env->GetBooleanField(opts, fid);

  fid = env->GetFieldID(cls, "energy_floor", "F");
  fbank_opts.energy_floor = env->GetFloatField(opts, fid);

  fid = env->GetFieldID(cls, "raw_energy", "Z");
  fbank_opts.raw_energy = env->GetBooleanField(opts, fid);

  fid = env->GetFieldID(cls, "htk_compat", "Z");
  fbank_opts.htk_compat = env->GetBooleanField(opts, fid);

  fid = env->GetFieldID(cls, "use_log_fbank", "Z");
  fbank_opts.use_log_fbank = env->GetBooleanField(opts, fid);

  fid = env->GetFieldID(cls, "use_power", "Z");
  fbank_opts.use_power = env->GetBooleanField(opts, fid);

  fid = env->GetFieldID(cls, "frame_opts", "LFrameExtractionOptions;");

  jobject frame_opts = env->GetObjectField(opts, fid);
  jclass frame_opts_cls = env->GetObjectClass(frame_opts);

  fid = env->GetFieldID(frame_opts_cls, "samp_freq", "F");
  fbank_opts.frame_opts.samp_freq = env->GetFloatField(frame_opts, fid);

  fid = env->GetFieldID(frame_opts_cls, "frame_shift_ms", "F");
  fbank_opts.frame_opts.frame_shift_ms = env->GetFloatField(frame_opts, fid);

  fid = env->GetFieldID(frame_opts_cls, "frame_length_ms", "F");
  fbank_opts.frame_opts.frame_length_ms = env->GetFloatField(frame_opts, fid);

  fid = env->GetFieldID(frame_opts_cls, "dither", "F");
  fbank_opts.frame_opts.dither = env->GetFloatField(frame_opts, fid);

  fid = env->GetFieldID(frame_opts_cls, "preemph_coeff", "F");
  fbank_opts.frame_opts.preemph_coeff = env->GetFloatField(frame_opts, fid);

  fid = env->GetFieldID(frame_opts_cls, "remove_dc_offset", "Z");
  fbank_opts.frame_opts.remove_dc_offset =
      env->GetBooleanField(frame_opts, fid);

  fid = env->GetFieldID(frame_opts_cls, "window_type", "Ljava/lang/String;");
  jstring window_type = (jstring)env->GetObjectField(frame_opts, fid);
  const char *p_window_type = env->GetStringUTFChars(window_type, nullptr);
  fbank_opts.frame_opts.window_type = p_window_type;
  env->ReleaseStringUTFChars(window_type, p_window_type);

  fid = env->GetFieldID(frame_opts_cls, "round_to_power_of_two", "Z");
  fbank_opts.frame_opts.round_to_power_of_two =
      env->GetBooleanField(frame_opts, fid);

  fid = env->GetFieldID(frame_opts_cls, "blackman_coeff", "F");
  fbank_opts.frame_opts.blackman_coeff = env->GetFloatField(frame_opts, fid);

  fid = env->GetFieldID(frame_opts_cls, "snip_edges", "Z");
  fbank_opts.frame_opts.snip_edges = env->GetBooleanField(frame_opts, fid);

  fid = env->GetFieldID(frame_opts_cls, "max_feature_vectors", "I");
  fbank_opts.frame_opts.max_feature_vectors = env->GetIntField(frame_opts, fid);

  fid = env->GetFieldID(cls, "mel_opts", "LMelBanksOptions;");
  jobject mel_opts = env->GetObjectField(opts, fid);
  jclass mel_opts_cls = env->GetObjectClass(mel_opts);

  fid = env->GetFieldID(mel_opts_cls, "num_bins", "I");
  fbank_opts.mel_opts.num_bins = env->GetIntField(mel_opts, fid);

  fid = env->GetFieldID(mel_opts_cls, "low_freq", "F");
  fbank_opts.mel_opts.low_freq = env->GetFloatField(mel_opts, fid);

  fid = env->GetFieldID(mel_opts_cls, "high_freq", "F");
  fbank_opts.mel_opts.high_freq = env->GetFloatField(mel_opts, fid);

  fid = env->GetFieldID(mel_opts_cls, "vtln_low", "F");
  fbank_opts.mel_opts.vtln_low = env->GetFloatField(mel_opts, fid);

  fid = env->GetFieldID(mel_opts_cls, "vtln_high", "F");
  fbank_opts.mel_opts.vtln_high = env->GetFloatField(mel_opts, fid);

  fid = env->GetFieldID(mel_opts_cls, "debug_mel", "Z");
  fbank_opts.mel_opts.debug_mel = env->GetBooleanField(mel_opts, fid);

  fid = env->GetFieldID(mel_opts_cls, "htk_mode", "Z");
  fbank_opts.mel_opts.htk_mode = env->GetBooleanField(mel_opts, fid);

  auto online_fbank = new knf::OnlineFbank(fbank_opts);

  return (jlong)online_fbank;
}

// Destroys the knf::OnlineFbank whose address is `ptr`.
JNIEXPORT void JNICALL Java_OnlineFbank_delete(JNIEnv * /*env*/,
                                               jobject /*obj*/, jlong ptr) {
  auto *fbank = reinterpret_cast<knf::OnlineFbank *>(ptr);
  delete fbank;
}

// Returns Dim() of the knf::OnlineFbank whose address is `ptr`.
JNIEXPORT jint JNICALL Java_OnlineFbank_dim(JNIEnv * /*env*/, jobject /*obj*/,
                                            jlong ptr) {
  const auto *fbank = reinterpret_cast<const knf::OnlineFbank *>(ptr);
  return fbank->Dim();
}

// Returns FrameShiftInSeconds() of the knf::OnlineFbank whose address is
// `ptr`.
JNIEXPORT jfloat JNICALL Java_OnlineFbank_frameShiftInSeconds(JNIEnv * /*env*/,
                                                              jobject /*obj*/,
                                                              jlong ptr) {
  const auto *fbank = reinterpret_cast<const knf::OnlineFbank *>(ptr);
  return fbank->FrameShiftInSeconds();
}

// Returns NumFramesReady() of the knf::OnlineFbank whose address is `ptr`.
JNIEXPORT jint JNICALL Java_OnlineFbank_numFramesReady(JNIEnv * /*env*/,
                                                       jobject /*obj*/,
                                                       jlong ptr) {
  const auto *fbank = reinterpret_cast<const knf::OnlineFbank *>(ptr);
  return fbank->NumFramesReady();
}

// Returns IsLastFrame(i) of the knf::OnlineFbank whose address is `ptr`.
JNIEXPORT jboolean JNICALL Java_OnlineFbank_isLastFrame(JNIEnv * /*env*/,
                                                        jobject /*obj*/,
                                                        jlong ptr, jint i) {
  const auto *fbank = reinterpret_cast<const knf::OnlineFbank *>(ptr);
  return fbank->IsLastFrame(i);
}

// Calls InputFinished() on the knf::OnlineFbank whose address is `ptr`.
JNIEXPORT void JNICALL Java_OnlineFbank_inputFinished(JNIEnv * /*env*/,
                                                      jobject /*obj*/,
                                                      jlong ptr) {
  auto *fbank = reinterpret_cast<knf::OnlineFbank *>(ptr);
  fbank->InputFinished();
}

// Feeds audio samples to the knf::OnlineFbank whose address is `ptr`.
//
// @param env         JNI environment of the calling thread.
// @param ptr         Address returned by Java_OnlineFbank_new.
// @param samples     The samples to process. If it is null, a
//                    java.lang.NullPointerException is raised and nothing is
//                    processed.
// @param sample_rate Passed unchanged to AcceptWaveform().
//
// If the array elements cannot be obtained, the function returns without
// processing anything and with the exception raised by the JVM pending.
JNIEXPORT void JNICALL Java_OnlineFbank_acceptWaveform(JNIEnv *env, jobject obj,
                                                       jlong ptr,
                                                       jfloatArray samples,
                                                       jfloat sample_rate) {
  if (samples == nullptr) {
    jclass npe = env->FindClass("java/lang/NullPointerException");
    if (npe != nullptr) {
      env->ThrowNew(npe, "samples");
    }
    return;
  }

  jsize n = env->GetArrayLength(samples);
  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
  if (p == nullptr) {
    return;  // GetFloatArrayElements failed; nothing to release
  }

  reinterpret_cast<knf::OnlineFbank *>(ptr)->AcceptWaveform(sample_rate, p, n);

  // The samples are only read, so there is nothing to copy back.
  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}

// Returns frame `i` of the knf::OnlineFbank whose address is `ptr`.
//
// @param env JNI environment of the calling thread.
// @param ptr Address returned by Java_OnlineFbank_new.
// @param i   Index of the requested frame. It must satisfy
//            0 <= i < NumFramesReady(); otherwise a
//            java.lang.IndexOutOfBoundsException is raised.
// @return A new float array with Dim() entries, or null with a Java
//         exception pending if `i` is out of range or the array cannot be
//         allocated.
JNIEXPORT jfloatArray JNICALL Java_OnlineFbank_getFrame(JNIEnv *env,
                                                        jobject obj, jlong ptr,
                                                        jint i) {
  auto online_fbank = reinterpret_cast<const knf::OnlineFbank *>(ptr);

  // Same upper bound as the one used by Java_OnlineFbank_getFrames.
  if (i < 0 || i >= online_fbank->NumFramesReady()) {
    jclass ioobe = env->FindClass("java/lang/IndexOutOfBoundsException");
    if (ioobe != nullptr) {
      env->ThrowNew(ioobe, "frame index out of range");
    }
    return nullptr;
  }

  auto dim = online_fbank->Dim();

  jfloatArray ans = env->NewFloatArray(dim);
  if (ans == nullptr) {
    return nullptr;  // allocation failed; an exception is pending
  }

  auto frame = online_fbank->GetFrame(i);
  env->SetFloatArrayRegion(ans, 0, dim, frame);

  return ans;
}

// Returns `n` consecutive frames, beginning with frame `start`, of the
// knf::OnlineFbank whose address is `ptr`.
//
// @param env   JNI environment of the calling thread.
// @param ptr   Address returned by Java_OnlineFbank_new.
// @param start Index of the first frame to return.
// @param n     Number of frames to return.
// @return A new float array with n * Dim() entries in which frame start + i
//         occupies the entries [i * Dim(), (i + 1) * Dim()). Returns null if
//         `start` or `n` is negative, if start + n exceeds NumFramesReady(),
//         or if the array cannot be allocated (an exception is pending in
//         the last case).
JNIEXPORT jfloatArray JNICALL Java_OnlineFbank_getFrames(JNIEnv *env,
                                                         jobject /*obj*/,
                                                         jlong ptr, jint start,
                                                         jint n) {
  auto online_fbank = reinterpret_cast<const knf::OnlineFbank *>(ptr);
  auto dim = online_fbank->Dim();

  // Written as `n > ready - start` instead of `start + n > ready` so that
  // the comparison cannot overflow. Negative values are rejected because
  // they would lead to a negative array size and a loop that does not stop
  // at `n`.
  if (start < 0 || n < 0 || n > online_fbank->NumFramesReady() - start) {
    return nullptr;
  }

  jfloatArray ans = env->NewFloatArray(n * dim);
  if (ans == nullptr) {
    return nullptr;  // allocation failed; an exception is pending
  }

  for (int32_t i = 0; i < n; ++i) {
    auto frame = online_fbank->GetFrame(start + i);
    env->SetFloatArrayRegion(ans, i * dim, dim, frame);
  }

  return ans;
}
./code/jni/feat-extractor/sherpa-ncnn.h
#ifndef SHERPA_NCNN_H_
#define SHERPA_NCNN_H_
#include "jni.h"

// JNI entry points backing the `external` functions of the Kotlin classes
// WaveReader (companion object) and Model. In every Model function that
// takes `ptr`, it is the value returned by Java_Model_new.

#ifdef __cplusplus
extern "C" {
#endif

// Reads the wave file `filename` with sherpa_ncnn::ReadWave and returns its
// samples as a new float array, or null when ReadWave reports a failure.
// "_00024" is the JNI encoding of '$' in the class name WaveReader$Companion.
JNIEXPORT jfloatArray JNICALL Java_WaveReader_00024Companion_readWave(
    JNIEnv *env, jclass cls, jstring filename, jfloat expected_sample_rate);

// Creates a sherpa_ncnn::Model from the fields of `config` (a Kotlin
// ModelConfig) and returns its address.
JNIEXPORT jlong JNICALL Java_Model_new(JNIEnv *env, jobject /*obj*/,
                                       jobject config);

// Deletes the sherpa_ncnn::Model that `ptr` points to.
JNIEXPORT void JNICALL Java_Model_delete(JNIEnv *env, jobject /*obj*/,
                                         jlong ptr);

// Returns Segment() of the native model.
JNIEXPORT jint JNICALL Java_Model_segment(JNIEnv *env, jobject /*obj*/,
                                          jlong ptr);

// Returns Offset() of the native model.
JNIEXPORT jint JNICALL Java_Model_offset(JNIEnv *env, jobject /*obj*/,
                                         jlong ptr);

#ifdef __cplusplus
}
#endif

#endif // SHERPA_NCNN_H_
./code/jni/feat-extractor/sherpa-ncnn.cc
#include "sherpa-ncnn.h"
#include "sherpa-ncnn/csrc/model.h"
#include "sherpa-ncnn/csrc/wave-reader.h"
#include <iostream>

// Reads a wave file with sherpa_ncnn::ReadWave.
//
// @param env                  JNI environment of the calling thread.
// @param filename             Path of the wave file.
// @param expected_sample_rate Passed unchanged to sherpa_ncnn::ReadWave.
// @return A new float array with the samples returned by ReadWave. Returns
//         null if ReadWave reports a failure, or, with a Java exception
//         pending, if the file name cannot be converted or the array cannot
//         be allocated.
JNIEXPORT jfloatArray JNICALL Java_WaveReader_00024Companion_readWave(
    JNIEnv *env, jclass cls, jstring filename, jfloat expected_sample_rate) {

  const char *p_filename = env->GetStringUTFChars(filename, nullptr);
  if (p_filename == nullptr) {
    return nullptr;  // GetStringUTFChars failed; an exception is pending
  }

  bool is_ok = false;
  std::vector<float> samples =
      sherpa_ncnn::ReadWave(p_filename, expected_sample_rate, &is_ok);
  env->ReleaseStringUTFChars(filename, p_filename);

  if (!is_ok) {
    return nullptr;
  }

  // JNI array sizes are jsize; make the narrowing from size_t explicit.
  const jsize num_samples = static_cast<jsize>(samples.size());

  jfloatArray ans = env->NewFloatArray(num_samples);
  if (ans == nullptr) {
    return nullptr;  // allocation failed; an exception is pending
  }

  env->SetFloatArrayRegion(ans, 0, num_samples, samples.data());
  return ans;
}

JNIEXPORT jlong JNICALL Java_Model_new(JNIEnv *env, jobject /*obj*/,
                                       jobject config) {
  sherpa_ncnn::ModelConfig model_config;

  jclass cls = env->GetObjectClass(config);

  jfieldID fid = env->GetFieldID(cls, "encoderParam", "Ljava/lang/String;");
  jstring s = (jstring)env->GetObjectField(config, fid);
  const char *p = env->GetStringUTFChars(s, nullptr);
  model_config.encoder_param = p;
  env->ReleaseStringUTFChars(s, p);

  fid = env->GetFieldID(cls, "encoderBin", "Ljava/lang/String;");
  s = (jstring)env->GetObjectField(config, fid);
  p = env->GetStringUTFChars(s, nullptr);
  model_config.encoder_bin = p;
  env->ReleaseStringUTFChars(s, p);

  fid = env->GetFieldID(cls, "decoderParam", "Ljava/lang/String;");
  s = (jstring)env->GetObjectField(config, fid);
  p = env->GetStringUTFChars(s, nullptr);
  model_config.decoder_param = p;
  env->ReleaseStringUTFChars(s, p);

  fid = env->GetFieldID(cls, "decoderBin", "Ljava/lang/String;");
  s = (jstring)env->GetObjectField(config, fid);
  p = env->GetStringUTFChars(s, nullptr);
  model_config.decoder_bin = p;
  env->ReleaseStringUTFChars(s, p);

  fid = env->GetFieldID(cls, "joinerParam", "Ljava/lang/String;");
  s = (jstring)env->GetObjectField(config, fid);
  p = env->GetStringUTFChars(s, nullptr);
  model_config.joiner_param = p;
  env->ReleaseStringUTFChars(s, p);

  fid = env->GetFieldID(cls, "joinerBin", "Ljava/lang/String;");
  s = (jstring)env->GetObjectField(config, fid);
  p = env->GetStringUTFChars(s, nullptr);
  model_config.joiner_bin = p;
  env->ReleaseStringUTFChars(s, p);

  fid = env->GetFieldID(cls, "numThreads", "I");
  model_config.num_threads = env->GetIntField(config, fid);
  std::cout << model_config.ToString() << "\n";

  auto model = sherpa_ncnn::Model::Create(model_config);

  return (jlong)(model.release());
}

// Destroys the sherpa_ncnn::Model whose address is `ptr`.
JNIEXPORT void JNICALL Java_Model_delete(JNIEnv * /*env*/, jobject /*obj*/,
                                         jlong ptr) {
  auto *model = reinterpret_cast<sherpa_ncnn::Model *>(ptr);
  delete model;
}

// Returns Segment() of the sherpa_ncnn::Model whose address is `ptr`.
JNIEXPORT jint JNICALL Java_Model_segment(JNIEnv * /*env*/, jobject /*obj*/,
                                          jlong ptr) {
  const auto *model = reinterpret_cast<const sherpa_ncnn::Model *>(ptr);
  return model->Segment();
}

// Returns Offset() of the sherpa_ncnn::Model whose address is `ptr`.
JNIEXPORT jint JNICALL Java_Model_offset(JNIEnv * /*env*/, jobject /*obj*/,
                                         jlong ptr) {
  const auto *model = reinterpret_cast<const sherpa_ncnn::Model *>(ptr);
  return model->Offset();
}
./code/jni/feat-extractor/Main.kt
// Demo entry point: creates an OnlineFbank, feeds it the samples of a wave
// file (if the file could be read), then creates a Model and prints its
// segment and offset.
fun main() {
  val fbankOpts = FbankOptions().apply { mel_opts.num_bins = 80 }
  val onlineFbank = OnlineFbank(fbankOpts)

  // readWave returns null on failure; in that case no samples are fed.
  WaveReader.readWave("./1089-134686-0001.wav", 16000.0f)?.let { samples ->
    onlineFbank.acceptWaveform(samples)
  }
  onlineFbank.inputFinished()

  val model = Model(
    ModelConfig(
      encoderParam = "./sherpa-ncnn-conv-emformer-transducer-2022-12-04/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param",
      encoderBin = "./sherpa-ncnn-conv-emformer-transducer-2022-12-04/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin",
      decoderParam = "./sherpa-ncnn-conv-emformer-transducer-2022-12-04/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param",
      decoderBin = "./sherpa-ncnn-conv-emformer-transducer-2022-12-04/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin",
      joinerParam = "./sherpa-ncnn-conv-emformer-transducer-2022-12-04/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.param",
      joinerBin = "./sherpa-ncnn-conv-emformer-transducer-2022-12-04/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin",
      numThreads = 4,
    )
  )
  println("segment: ${model.segment}")
  println("offset: ${model.offset}")
}
./code/jni/feat-extractor/OnlineFeature.kt
// Kotlin-side counterpart of the `frame_opts` member of knf::FbankOptions.
//
// The native function Java_OnlineFbank_new looks every property up by its
// name and JNI type ("F" for Float, "Z" for Boolean, "I" for Int,
// "Ljava/lang/String;" for String) and copies it into the member of the same
// name. Renaming a property or changing its type therefore requires the
// matching change in the native code.
data class FrameExtractionOptions(
  var samp_freq: Float = 16000.0f,
  var frame_shift_ms: Float = 10.0f,
  var frame_length_ms: Float = 25.0f,
  var dither: Float = 0.0f,
  var preemph_coeff: Float = 0.97f,
  var remove_dc_offset: Boolean = true,
  var window_type: String = "povey",
  var round_to_power_of_two: Boolean = true,
  var blackman_coeff: Float = 0.42f,
  var snip_edges: Boolean = true,
  var max_feature_vectors: Int = -1
)

// Kotlin-side counterpart of the `mel_opts` member of knf::FbankOptions.
//
// The native function Java_OnlineFbank_new looks every property up by its
// name and JNI type and copies it into the member of the same name, so the
// names and types here must stay in sync with the native code.
data class MelBanksOptions(
  var num_bins : Int = 25,
  var low_freq : Float = 20.0f,
  var high_freq : Float = 0.0f,
  var vtln_low : Float = 100.0f,
  var vtln_high : Float = -500.0f,
  var debug_mel : Boolean = false,
  var htk_mode : Boolean = false,
)

// Kotlin-side counterpart of knf::FbankOptions; an instance is handed to
// OnlineFbank.
//
// The native function Java_OnlineFbank_new looks every property up by its
// name and JNI type. `frame_opts` and `mel_opts` are looked up with the
// class signatures "LFrameExtractionOptions;" and "LMelBanksOptions;", so
// those two classes have to stay in the default package under these names.
data class FbankOptions(
  var frame_opts: FrameExtractionOptions = FrameExtractionOptions(),
  var mel_opts: MelBanksOptions = MelBanksOptions(),
  var use_energy: Boolean = false,
  var energy_floor: Float = 0.0f,
  var raw_energy: Boolean = true,
  var htk_compat: Boolean = false,
  var use_log_fbank: Boolean = true,
  var use_power: Boolean = true,
)


// Kotlin handle of a native knf::OnlineFbank.
//
// The native object is created from `opts` when an instance is constructed
// and is deleted in finalize(). All operations are forwarded to the native
// functions named Java_OnlineFbank_<name>, so the names of the `external`
// functions below must not be changed without changing the native side.
class OnlineFbank(var opts: FbankOptions) {
  // Address of the native object, as returned by new().
  private var ptr: Long

  init {
    ptr = new(opts)
  }

  // Releases the native object.
  protected fun finalize() {
    delete(ptr)
  }

  val dim: Int
      get() = dim(ptr)

  val frameShiftInSeconds: Float
      get() = frameShiftInSeconds(ptr)

  val numFramesReady: Int
      get() = numFramesReady(ptr)

  fun isLastFrame(i: Int) :Boolean = isLastFrame(ptr, i)
  fun inputFinished() = inputFinished(ptr)

  // Feeds `samples` to the native object, using opts.frame_opts.samp_freq as
  // their sample rate.
  fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples, opts.frame_opts.samp_freq)

  // Returns the `dim` values of frame `i`.
  fun getFrame(i: Int): FloatArray = getFrame(ptr, i)

  // Returns n * dim values: the frames start, start + 1, ..., start + n - 1
  // one after another.
  //
  // Throws IndexOutOfBoundsException when the native side rejects the
  // requested range, which it signals by returning null.
  fun getFrames(start: Int, n: Int): FloatArray =
      getFrames(ptr, start, n)
          ?: throw IndexOutOfBoundsException(
              "Cannot get $n frame(s) starting at $start: $numFramesReady frame(s) ready")

  private external fun new(opts: FbankOptions): Long
  private external fun delete(ptr: Long)
  private external fun dim(ptr: Long): Int
  private external fun frameShiftInSeconds(ptr: Long): Float
  private external fun numFramesReady(ptr: Long): Int
  private external fun isLastFrame(ptr: Long, i: Int): Boolean
  private external fun inputFinished(ptr: Long)
  private external fun acceptWaveform(ptr: Long, samples: FloatArray, sample_rate: Float)
  private external fun getFrame(ptr: Long, i: Int): FloatArray
  // Nullable: the native implementation returns null for an invalid range.
  private external fun getFrames(ptr: Long, start: Int, n: Int): FloatArray?

  companion object {
    init {
      System.loadLibrary("sherpa-ncnn")
    }
  }
}
./code/jni/feat-extractor/WaveReader.kt
// Gives access to the native wave file reader.
//
// readWave is implemented by the native function
// Java_WaveReader_00024Companion_readWave; that name encodes the fact that
// the function is declared in the companion object, so it has to stay there.
class WaveReader {

  companion object {
    // Read a mono wave file.
    // No resampling is made.
    //
    // Returns the samples, or null if the native reader reports a failure.
    external fun readWave(filename: String, expected_sample_rate: Float = 16000.0f) : FloatArray?

    // Load the shared library that contains the native implementation.
    init {
      System.loadLibrary("sherpa-ncnn")
    }
  }
}
./code/jni/feat-extractor/Model.kt
// Configuration handed to Model.
//
// The native function Java_Model_new looks every property up by its name
// and JNI type ("Ljava/lang/String;" for the six String properties, "I" for
// numThreads) and copies it into a sherpa_ncnn::ModelConfig, so the names
// and types here must stay in sync with the native code.
data class ModelConfig(
  var encoderParam: String,
  var encoderBin: String,
  var decoderParam: String,
  var decoderBin: String,
  var joinerParam: String,
  var joinerBin: String,
  var numThreads: Int = 4,
)

// Kotlin handle of a native sherpa_ncnn::Model.
//
// The native model is created from `config` when an instance is constructed
// and is deleted in finalize(). All operations are forwarded to the native
// functions named Java_Model_<name>, so the names of the `external`
// functions below must not be changed without changing the native side.
class Model(var config: ModelConfig) {
  // Address of the native model, as returned by new().
  private var ptr: Long

  init {
    ptr = new(config)
  }

  // Releases the native model.
  protected fun finalize() {
    delete(ptr)
  }

  val segment: Int
      get() = segment(ptr)

  val offset: Int
      get() = offset(ptr)

  private external fun new(config: ModelConfig): Long
  private external fun delete(ptr: Long)
  private external fun segment(ptr: Long): Int
  private external fun offset(ptr: Long): Int

  companion object {
    // Load the shared library that contains the native implementation, so
    // that Model also works when it is the first class to use the library.
    // Loading a library that is already loaded has no effect.
    init {
      System.loadLibrary("sherpa-ncnn")
    }
  }
}