Java Android Tokenizer Code

Created on: September 28, 2024

Answered using OpenAI o1-preview

Question

Rewrite all this code in Java for Android:

export class PreTrainedTokenizer {

constructor(tokenizerJSON, tokenizerConfig) {

this._tokenizer_config = tokenizerConfig;
this.normalizer = Normalizer.fromConfig(tokenizerJSON.normalizer);
this.pre_tokenizer = PreTokenizer.fromConfig(tokenizerJSON.pre_tokenizer);
this.model = TokenizerModel.fromConfig(tokenizerJSON.model, tokenizerConfig);
this.post_processor = PostProcessor.fromConfig(tokenizerJSON.post_processor);

this.special_tokens = [];
this.all_special_ids = [];


this.added_tokens = [];
for (const addedToken of tokenizerJSON.added_tokens) {
const token = new AddedToken(addedToken);
this.added_tokens.push(token);

this.model.tokens_to_ids.set(token.content, token.id);
this.model.vocab[token.id] = token.content;

if (token.special) {
this.special_tokens.push(token.content);
this.all_special_ids.push(token.id);
}
}

this.additional_special_tokens = tokenizerConfig.additional_special_tokens ?? [];
this.special_tokens.push(...this.additional_special_tokens);
this.special_tokens = [...new Set(this.special_tokens)];

this.added_tokens_regex = this.added_tokens.length > 0 ? new RegExp(
this.added_tokens.map(x => `${x.lstrip ? '\\s*' : ''}(${escapeRegExp(x.content)})${x.rstrip ? '\\s*' : ''}`).join('|')
) : null;

this.mask_token = this.getToken('mask_token');
this.mask_token_id = this.model.tokens_to_ids.get(this.mask_token);

this.pad_token = this.getToken('pad_token', 'eos_token');
this.pad_token_id = this.model.tokens_to_ids.get(this.pad_token);

this.sep_token = this.getToken('sep_token');
this.sep_token_id = this.model.tokens_to_ids.get(this.sep_token);

this.unk_token = this.getToken('unk_token');
this.unk_token_id = this.model.tokens_to_ids.get(this.unk_token);

this.model_max_length = tokenizerConfig.model_max_length;


this.remove_space = tokenizerConfig.remove_space;

this.clean_up_tokenization_spaces = tokenizerConfig.clean_up_tokenization_spaces ?? true;
this.do_lowercase_and_remove_accent = tokenizerConfig.do_lowercase_and_remove_accent ?? false;

this.padding_side = 'right';

this.legacy = false;

}

getToken(...keys) {
for (const key of keys) {
const item = this._tokenizer_config[key];

if (!item) continue;

if (typeof item === 'object') {
if (item.__type === 'AddedToken') {
return item.content;
} else {
throw Error(`Unknown token: ${item}`);
}
} else {
return item;
}
}
return null;
}

_encode_text(text) {
if (text === null) return null;

const sections = this.added_tokens_regex ? text.split(this.added_tokens_regex).filter(x => x) : [text];

const tokens = sections.map((x, section_index) => {
const addedToken = this.added_tokens.find(t => t.content === x);
if (addedToken !== undefined) {
return x
} else {
if (this.remove_space === true) {
x = x.trim().split(/\s+/).join(' ');
}
if (this.do_lowercase_and_remove_accent) {
x = lowercase_and_remove_accent(x);
}

if (this.normalizer !== null) {
x = this.normalizer.normalize(x);
}

if (x.length === 0) {
return [];
}

const sectionTokens = (this.pre_tokenizer !== null) ? this.pre_tokenizer.pre_tokenize(x, {
section_index,
}) : [x];

const tokens = this.model._call(sectionTokens);

return tokens;
}
}).flat();

return tokens;
}

_encode_plus(text, {
add_special_tokens = true,
} = {}) {
const tokens = this._encode_text(text);

const combinedTokens = this.post_processor
? this.post_processor.post_process(tokens, null, { add_special_tokens })
: { tokens };

console.log(combinedTokens.tokens);
const input_ids = this.model.convert_tokens_to_ids(combinedTokens.tokens);

return {
input_ids,
attention_mask: new Array(input_ids.length).fill(1),
};
}

encode(text, {
add_special_tokens = true,
} = {}) {
const { input_ids } = this._encode_plus(text, {
add_special_tokens,
});
return input_ids;
}

}

/**
* Helper function to remove accents from a string.
* @param {string} text The text to remove accents from.
* @returns {string} The text with accents removed.
*/
function remove_accents(text) {
return text.replace(/[\u0300-\u036f]/g, '');
}

/**
* Helper function to lowercase a string and remove accents.
* @param {string} text The text to lowercase and remove accents from.
* @returns {string} The lowercased text with accents removed.
*/
function lowercase_and_remove_accent(text) {
return remove_accents(text.toLowerCase());
}

class Normalizer {

constructor(config) {
this.config = config;
}


static fromConfig(config) {
if (config === null) return null;
switch (config.type) {
case 'Precompiled':
return new Precompiled(config);
default:
throw new Error(`Unknown Normalizer type: ${config.type}`);
}
}


normalize(text) {
throw Error("normalize should be implemented in subclass.")
}

}

class Precompiled extends Normalizer {

constructor(config) {
super(config);
//this.charsmap = config.precompiled_charsmap;
}


normalize(text) {
text = text.replace(/[\u0001-\u0008\u000B\u000E-\u001F\u007F\u008F\u009F]/gm, ''); // Remove control characters
text = text.replace(/[\u0009\u000A\u000C\u000D\u1680\u200B\u200C\u200E\u200F\u2028\u2029\u2581\uFEFF\uFFFD]/gm, '\u0020'); // Replace certain characters with a space

if (text.includes('\uFF5E')) {
const parts = text.split('\uFF5E');
text = parts.map(part => part.normalize('NFKC')).join('\uFF5E');
} else {
text = text.normalize('NFKC');
}

return text;
}
}

class PreTokenizer {

static fromConfig(config) {
if (config === null) return null;

switch (config.type) {
case 'Sequence':
return new PreTokenizerSequence(config);
case 'WhitespaceSplit':
return new WhitespaceSplit(config);
case 'Metaspace':
return new MetaspacePreTokenizer(config);
default:
throw new Error(`Unknown PreTokenizer type: ${config.type}`);
}
}


pre_tokenize_text(text, options) {
throw Error("pre_tokenize_text should be implemented in subclass.")
}


pre_tokenize(text, options) {
return (Array.isArray(text)
? text.map(x => this.pre_tokenize_text(x, options))
: this.pre_tokenize_text(text, options)
).flat();
}

}

class PreTokenizerSequence extends PreTokenizer {

constructor(config) {
super();
this.tokenizers = config.pretokenizers.map(x => PreTokenizer.fromConfig(x));
}


pre_tokenize_text(text, options) {
return this.tokenizers.reduce((preTokenizedText, tokenizer) => {
return tokenizer.pre_tokenize(preTokenizedText, options);
}, [text]);
}
}

class WhitespaceSplit extends PreTokenizer {

constructor(config) {
super();
}

pre_tokenize_text(text, options) {
return whitespace_split(text);
}
}

class MetaspacePreTokenizer extends PreTokenizer {

constructor(config) {
super();

this.addPrefixSpace = config.add_prefix_space;
this.replacement = config.replacement;
this.strRep = config.str_rep || this.replacement;
this.prepend_scheme = config.prepend_scheme ?? 'always';
}


pre_tokenize_text(text, {
section_index = undefined,
} = {}) {

let normalized = text.replaceAll(' ', this.strRep);

if (
(this.addPrefixSpace && !normalized.startsWith(this.replacement))

&& (
this.prepend_scheme === 'always' ||
(this.prepend_scheme === 'first' && section_index === 0)
)
) {
normalized = this.strRep + normalized;
}
return [normalized];
}
}

class TokenizerModel {

constructor(config) {
this.config = config;


this.vocab = [];


this.tokens_to_ids = new Map();

this.unk_token_id = undefined;
this.unk_token = undefined;
this.end_of_word_suffix = undefined;


this.fuse_unk = this.config.fuse_unk ?? false;
}


static fromConfig(config, ...args) {
switch (config.type) {
case 'Unigram':
return new Unigram(config, ...args);
default:
throw new Error(`Unknown TokenizerModel type: ${config.type}`);
}
}


_call(tokens) {
let ids = this.encode(tokens);
if (this.fuse_unk) {
ids = fuse(ids, this.unk_token_id, this.tokens_to_ids);
}
return ids;
}


encode(tokens) {
throw Error("encode should be implemented in subclass.")
}


convert_tokens_to_ids(tokens) {
return tokens.map(t => this.tokens_to_ids.get(t) ?? this.unk_token_id);
}


convert_ids_to_tokens(ids) {
return ids.map(i => this.vocab[i] ?? this.unk_token);
}
}

class Unigram extends TokenizerModel {

constructor(config, moreConfig) {
super(config);

const vocabSize = config.vocab.length;
this.vocab = new Array(vocabSize);
this.scores = new Array(vocabSize);
for (let i = 0; i < vocabSize; ++i) {
const piece = config.vocab[i];
this.vocab[i] = piece[0];
this.scores[i] = piece[1];
}


this.unk_token_id = config.unk_id;
this.unk_token = this.vocab[config.unk_id];

this.tokens_to_ids = new Map(this.vocab.map((x, i) => [x, i]));
this.bosToken = ' ';

this.bosTokenId = this.tokens_to_ids.get(this.bosToken);
this.eosToken = moreConfig.eos_token;

this.eosTokenId = this.tokens_to_ids.get(this.eosToken);
this.unkToken = this.vocab[this.unk_token_id];

this.minScore = min(this.scores)[0];

this.unkScore = this.minScore - 10.0;
this.scores[this.unk_token_id] = this.unkScore;

this.trie = new CharTrie();
this.trie.extend(this.vocab);

this.fuse_unk = true;
}


populateNodes(lattice) {
const sentence = lattice.sentence;
const len = sentence.length;
let beginPos = 0;
while (beginPos < len) {
const mblen = 1;
let hasSingleNode = false;
const tokens = [];

for (let token of this.trie.commonPrefixSearch(sentence.slice(beginPos))) {
tokens.push(token);
const tokenId = this.tokens_to_ids.get(token);
const tokenScore = this.scores[tokenId];
const n = token.length;
lattice.insert(beginPos, n, tokenScore, tokenId);
if (!hasSingleNode && n === mblen) {
hasSingleNode = true;
}
}
if (!hasSingleNode) {
lattice.insert(beginPos, mblen, this.unkScore, this.unk_token_id);
}
beginPos += mblen;
}
}


tokenize(normalized) {
const lattice = new TokenLattice(normalized, this.bosTokenId, this.eosTokenId);
this.populateNodes(lattice);
return lattice.tokens();
}


encode(tokens) {
const toReturn = [];
for (const token of tokens) {
const tokenized = this.tokenize(token);
toReturn.push(...tokenized);
}
return toReturn;
}

}

function min(arr) {
if (arr.length === 0) throw Error('Array must not be empty');
let min = arr[0];
let indexOfMin = 0;
for (let i = 1; i < arr.length; ++i) {
if (arr[i] < min) {
min = arr[i];
indexOfMin = i;
}
}
return [min, indexOfMin];
}

class CharTrie {

constructor() {
this.root = CharTrieNode.default();
}

extend(texts) {

if (texts.length === 0) return;

for (let text of texts) {
let node = this.root;
for (let ch of text) {
let child = node.children.get(ch);
if (child === undefined) {
child = CharTrieNode.default();
node.children.set(ch, child);
}
node = child;
}
node.isLeaf = true;
}

}

*commonPrefixSearch(text) {
let node = this.root;
let prefix = "";
for (let i = 0; i < text.length && node !== undefined; ++i) {
const ch = text[i];
prefix += ch;
node = node.children.get(ch);
if (node !== undefined && node.isLeaf) {
yield prefix;
}
}
}
}

class CharTrieNode {

constructor(isLeaf, children) {
this.isLeaf = isLeaf;
this.children = children;
}


static default() {
return new CharTrieNode(false, new Map());
}

}

class PostProcessor {


constructor(config) {
this.config = config;
}


static fromConfig(config) {
if (config === null) return null;
switch (config.type) {
case 'TemplateProcessing':
return new TemplateProcessing(config);
default:
throw new Error(`Unknown PostProcessor type: ${config.type}`);
}
}


post_process(tokens, ...args) {
throw Error("post_process should be implemented in subclass.")
}

}

class TemplateProcessing extends PostProcessor {

constructor(config) {
super(config);

this.single = config.single;
this.pair = config.pair;
}


post_process(tokens, tokens_pair = null, {
add_special_tokens = true,
} = {}) {
const type = tokens_pair === null ? this.single : this.pair

//Here I decided to display tokens and type variables in the console to make it easier to understand their structure
console.log(1, tokens); //1 ["▁I", "▁like", "▁cat", "s"]
console.log(2, type); //2 [{"SpecialToken": {"id": "<s>", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "</s>", "type_id": 0}}]

let processedTokens = [];
let types = [];
for (const item of type) {
if ('SpecialToken' in item) {
if (add_special_tokens) {
processedTokens.push(item.SpecialToken.id);
types.push(item.SpecialToken.type_id);
}
} else if ('Sequence' in item) {
if (item.Sequence.id === 'A') {
processedTokens = mergeArrays(processedTokens, tokens);
types = mergeArrays(types, new Array(tokens.length).fill(item.Sequence.type_id));

} else if (item.Sequence.id === 'B') {
processedTokens = mergeArrays(processedTokens, tokens_pair);
types = mergeArrays(types, new Array(tokens_pair.length).fill(item.Sequence.type_id));
}
}
}

//Here I decided to display tokens and type variables in the console to make it easier to understand their structure
console.log(3, processedTokens) //3 ["<s>", "▁I", "▁like", "▁cat", "s", "</s>"]
console.log(4, types)//4 [0, 0, 0, 0, 0, 0]

return { tokens: processedTokens, token_type_ids: types };
}
}

class AddedToken {

constructor(config) {
this.content = config.content;
this.id = config.id;
this.single_word = config.single_word ?? false;
this.lstrip = config.lstrip ?? false;
this.rstrip = config.rstrip ?? false;
this.special = config.special ?? false;
this.normalized = config.normalized ?? null;
}
}

function escapeRegExp(string) {
return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
}

function whitespace_split(text) {
return text.match(/\S+/g) || [];
}

class TokenLattice {

constructor(sentence, bosTokenId, eosTokenId) {
this.sentence = sentence;
this.len = sentence.length;
this.bosTokenId = bosTokenId;
this.eosTokenId = eosTokenId;
this.nodes = [];
this.beginNodes = Array.from({ length: this.len + 1 }, () => []);
this.endNodes = Array.from({ length: this.len + 1 }, () => []);

const bos = new TokenLatticeNode(this.bosTokenId, 0, 0, 0, 0.0);
const eos = new TokenLatticeNode(this.eosTokenId, 1, this.len, 0, 0.0);
this.nodes.push(bos.clone());
this.nodes.push(eos.clone());
this.beginNodes[this.len].push(eos);
this.endNodes[0].push(bos);
}


insert(pos, length, score, tokenId) {
const nodeId = this.nodes.length;
const node = new TokenLatticeNode(tokenId, nodeId, pos, length, score);
this.beginNodes[pos].push(node);
this.endNodes[pos + length].push(node);
this.nodes.push(node);
}


viterbi() {
const len = this.len;
let pos = 0;
while (pos <= len) {
if (this.beginNodes[pos].length == 0) {
return [];
}
for (let rnode of this.beginNodes[pos]) {
rnode.prev = null;
let bestScore = 0.0;
let bestNode = null;
for (let lnode of this.endNodes[pos]) {
const score = lnode.backtraceScore + rnode.score;
if (bestNode === null || score > bestScore) {
bestNode = lnode.clone();
bestScore = score;
}
}

if (bestNode !== null) {
rnode.prev = bestNode;
rnode.backtraceScore = bestScore;
} else {
return [];
}
}
++pos;
}

const results = [];
const root = this.beginNodes[len][0];
const prev = root.prev;
if (prev === null) {
return [];
}

let node = prev.clone();
while (node.prev !== null) {
results.push(node.clone());
const n = node.clone();
node = n.prev.clone();
}

results.reverse();
return results;
}


piece(node) {
return this.sentence.slice(node.pos, node.pos + node.length);
}


tokens() {
const nodes = this.viterbi();
return nodes.map(x => this.piece(x));
}


tokenIds() {
const nodes = this.viterbi();
return nodes.map(x => x.tokenId);
}
}

class TokenLatticeNode {

constructor(tokenId, nodeId, pos, length, score) {
this.tokenId = tokenId;
this.nodeId = nodeId;
this.pos = pos;
this.length = length;
this.score = score;
this.prev = null;
this.backtraceScore = 0.0;
}


clone() {
const n = new TokenLatticeNode(this.tokenId, this.nodeId, this.pos, this.length, this.score);
n.prev = this.prev;
n.backtraceScore = this.backtraceScore;
return n;
}
}

function fuse(arr, value, mapping) {
const fused = [];
let i = 0;
while (i < arr.length) {
fused.push(arr[i])
if ((mapping.get(arr[i]) ?? value) !== value) {
++i;
continue;
}

while (i < arr.length && (mapping.get(arr[i]) ?? value) === value) {
++i;
}
}

return fused;
}



function mergeArrays(...arrs) {
return Array.prototype.concat.apply([], arrs);
}


I want to use this code you rewrote in this way:

package com.transformers;

import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReactContextBaseJavaModule;
import com.facebook.react.bridge.ReactMethod;
import com.facebook.react.bridge.Promise;
import com.facebook.react.bridge.WritableArray;
import com.facebook.react.bridge.Arguments;
import com.transformers.tokenizer.PreTrainedTokenizer;

import android.os.Handler;
import android.os.Looper;
import android.util.Log;

import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.json.JSONObject;
import org.json.JSONException;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.FileNotFoundException;

public class TokenizerModule extends ReactContextBaseJavaModule {

private final ReactApplicationContext reactContext;
private final ExecutorService executorService;
private static final String TAG = "TokenizerModule";
private PreTrainedTokenizer tokenizer;

public TokenizerModule(ReactApplicationContext reactContext) {
super(reactContext);
this.reactContext = reactContext;
this.executorService = Executors.newCachedThreadPool();
}

@Override
public String getName() {
return TAG;
}

@ReactMethod
public void initializeTokenizer(String tokenizerJSONPath, String tokenizerConfigPath, Promise promise) {
executorService.execute(() -> {
try {
JSONObject tokenizerJSON = readJSONFromFile(tokenizerJSONPath);
JSONObject tokenizerConfig = readJSONFromFile(tokenizerConfigPath);

tokenizer = new PreTrainedTokenizer(tokenizerJSON, tokenizerConfig);

new Handler(Looper.getMainLooper()).post(() ->
promise.resolve("Tokenizer initialized successfully")
);
} catch (FileNotFoundException e) {
new Handler(Looper.getMainLooper()).post(() ->
promise.reject("FILE_NOT_FOUND", "File not found: " + e.getMessage())
);
} catch (Exception e) {
new Handler(Looper.getMainLooper()).post(() ->
promise.reject("INITIALIZATION_ERROR", e.getMessage())
);
}
});
}

@ReactMethod
public void encode(String text, Promise promise) {
executorService.execute(() -> {
try {
if (tokenizer == null) {
throw new IllegalStateException("Tokenizer not initialized. Call initializeTokenizer first.");
}
List<Integer> encodedIds = tokenizer.encode(text);


WritableArray result = Arguments.createArray();
for (Integer id : encodedIds) {
result.pushInt(id);
}

new Handler(Looper.getMainLooper()).post(() ->
promise.resolve(result)
);
} catch (IllegalStateException e) {
new Handler(Looper.getMainLooper()).post(() ->
promise.reject("TOKENIZER_NOT_INITIALIZED", e.getMessage())
);
} catch (Exception e) {
new Handler(Looper.getMainLooper()).post(() ->
promise.reject("ENCODING_ERROR", e.getMessage())
);
}
});
}

@ReactMethod
public void cleanup() {
executorService.shutdown();
}

private JSONObject readJSONFromFile(String filePath) throws IOException, JSONException {
StringBuilder content = new StringBuilder();
try (BufferedReader br = new BufferedReader(new FileReader(filePath))) {
String line;
while ((line = br.readLine()) != null) {
content.append(line);
}
}

return new JSONObject(content.toString());
}



}

Thought

Migrating code

I’m translating JavaScript code for tokenization into Android Java, ensuring compatibility and smooth integration with the given package.

Ensuring compatibility

I’m focusing on verifying the rewritten Java code's alignment with the specified usage example in an Android module. Ensuring accuracy and adherence to safety and policy standards is crucial.

I’m identifying the structure of the JavaScript code, noting the class PreTrainedTokenizer with its methods and helper functions. I’m set to translate all classes into Java for Android compatibility.

Translating constructs

I’m crafting Java equivalents for JavaScript constructs in the PreTrainedTokenizer. Java's precision contrasts with JavaScript's dynamic typing. Considering the inclusion of classes like Normalizer and TokenizerModel, I plan to create Java classes for each.

Translating to Java

I’m defining class variables in Java, ensuring appropriate data structures, and explicitly setting types, unlike JavaScript. This involves parsing JSON files into Java objects.

Translating JavaScript to Java

I’m converting JavaScript tokenizer code to Java, carefully assigning values to private members for tokenizer_config, normalizer, pre_tokenizer, model, post_processor, and initializing lists for special_tokens, all_special_ids, and added_tokens.

Establishing configurations

I’m gathering various components to configure a tokenizer. This involves setting values for normalization, pre-tokenization, model handling, special tokens, and regex patterns, ensuring each component is accurately initialized.

Focusing on token retrieval

I’m working through the implementation of the getToken method, which is essential for token handling in the system.

Translating functions

First, I exchanged JavaScript objects for Java objects, then translated the token retrieval function. Next, I embarked on a Java method for _encode_text, navigating through text normalization and tokenization to match JavaScript's logic.

Transforming text

Splitting the text into sections based on regex, keeping non-empty parts, and handling null regex cases in Java.

Transforming text

I'm parsing the text into sections, filtering out empty ones, and encoding each section into tokens, applying normalization and tokenization where needed.

Breaking down the encode method

I’m working through the encode method in Java, aligning it with the JavaScript function's logic. This involves handling text encoding, token processing, and converting tokens to IDs, with optional special token handling.

Answer

Sure! Below is the Java implementation of the provided JavaScript code, tailored for Android. The classes are organized under the com.transformers.tokenizer package to align with your usage example. This code rewrites the entire tokenizer logic, including all the helper classes and functions.

Please make sure to include all the necessary imports and handle any potential exceptions appropriately.


File Structure:

  • com.transformers.tokenizer.PreTrainedTokenizer.java
  • com.transformers.tokenizer.NormalizerClass.java
  • com.transformers.tokenizer.Precompiled.java
  • com.transformers.tokenizer.PreTokenizer.java
  • com.transformers.tokenizer.PreTokenizerSequence.java
  • com.transformers.tokenizer.WhitespaceSplit.java
  • com.transformers.tokenizer.MetaspacePreTokenizer.java
  • com.transformers.tokenizer.TokenizerModel.java
  • com.transformers.tokenizer.Unigram.java
  • com.transformers.tokenizer.CharTrie.java
  • com.transformers.tokenizer.CharTrieNode.java
  • com.transformers.tokenizer.PostProcessor.java
  • com.transformers.tokenizer.TemplateProcessing.java
  • com.transformers.tokenizer.AddedToken.java
  • com.transformers.tokenizer.TokenLattice.java
  • com.transformers.tokenizer.TokenLatticeNode.java

1. PreTrainedTokenizer.java

package com.transformers.tokenizer;

import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

import java.text.Normalizer;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

public class PreTrainedTokenizer {

    private JSONObject _tokenizer_config;
    private NormalizerClass normalizer;
    private PreTokenizer pre_tokenizer;
    private TokenizerModel model;
    private PostProcessor post_processor;
    private List<String> special_tokens;
    private List<Integer> all_special_ids;
    private List<AddedToken> added_tokens;
    private Pattern added_tokens_regex;
    private String mask_token;
    private Integer mask_token_id;
    private String pad_token;
    private Integer pad_token_id;
    private String sep_token;
    private Integer sep_token_id;
    private String unk_token;
    private Integer unk_token_id;
    private int model_max_length;
    private boolean remove_space;
    private boolean clean_up_tokenization_spaces;
    private boolean do_lowercase_and_remove_accent;
    private String padding_side;
    private boolean legacy;

    public PreTrainedTokenizer(JSONObject tokenizerJSON, JSONObject tokenizerConfig) throws JSONException {
        this._tokenizer_config = tokenizerConfig;
        this.normalizer = NormalizerClass.fromConfig(tokenizerJSON.optJSONObject("normalizer"));
        this.pre_tokenizer = PreTokenizer.fromConfig(tokenizerJSON.optJSONObject("pre_tokenizer"));
        this.model = TokenizerModel.fromConfig(tokenizerJSON.optJSONObject("model"), tokenizerConfig);
        this.post_processor = PostProcessor.fromConfig(tokenizerJSON.optJSONObject("post_processor"));

        this.special_tokens = new ArrayList<>();
        this.all_special_ids = new ArrayList<>();
        this.added_tokens = new ArrayList<>();

        JSONArray addedTokensArray = tokenizerJSON.optJSONArray("added_tokens");
        if (addedTokensArray != null) {
            for (int i = 0; i < addedTokensArray.length(); i++) {
                JSONObject addedTokenObj = addedTokensArray.getJSONObject(i);
                AddedToken token = new AddedToken(addedTokenObj);
                this.added_tokens.add(token);

                this.model.tokens_to_ids.put(token.content, token.id);
                // vocab is a List<String>, so grow it as needed before writing at index token.id
                while (this.model.vocab.size() <= token.id) {
                    this.model.vocab.add(null);
                }
                this.model.vocab.set(token.id, token.content);

                if (token.special) {
                    this.special_tokens.add(token.content);
                    this.all_special_ids.add(token.id);
                }
            }
        }

        JSONArray additional_special_tokens = tokenizerConfig.optJSONArray("additional_special_tokens");
        if (additional_special_tokens != null) {
            for (int i = 0; i < additional_special_tokens.length(); i++) {
                String token = additional_special_tokens.getString(i);
                this.special_tokens.add(token);
            }
        }

        // Remove duplicates
        Set<String> set = new HashSet<>(this.special_tokens);
        this.special_tokens = new ArrayList<>(set);

        // Build added_tokens_regex
        if (this.added_tokens.size() > 0) {
            StringBuilder patternBuilder = new StringBuilder();
            for (AddedToken token : this.added_tokens) {
                if (patternBuilder.length() > 0) {
                    patternBuilder.append("|");
                }
                patternBuilder.append(token.lstrip ? "\\s*" : "");
                patternBuilder.append("(").append(escapeRegExp(token.content)).append(")");
                patternBuilder.append(token.rstrip ? "\\s*" : "");
            }
            this.added_tokens_regex = Pattern.compile(patternBuilder.toString());
        } else {
            this.added_tokens_regex = null;
        }

        // Set tokens
        this.mask_token = this.getToken("mask_token");
        if (this.mask_token != null) {
            this.mask_token_id = this.model.tokens_to_ids.get(this.mask_token);
        } else {
            this.mask_token_id = null;
        }

        this.pad_token = this.getToken("pad_token", "eos_token");
        if (this.pad_token != null) {
            this.pad_token_id = this.model.tokens_to_ids.get(this.pad_token);
        } else {
            this.pad_token_id = null;
        }

        this.sep_token = this.getToken("sep_token");
        if (this.sep_token != null) {
            this.sep_token_id = this.model.tokens_to_ids.get(this.sep_token);
        } else {
            this.sep_token_id = null;
        }

        this.unk_token = this.getToken("unk_token");
        if (this.unk_token != null) {
            this.unk_token_id = this.model.tokens_to_ids.get(this.unk_token);
        } else {
            this.unk_token_id = null;
        }

        this.model_max_length = tokenizerConfig.optInt("model_max_length");
        this.remove_space = tokenizerConfig.optBoolean("remove_space");
        this.clean_up_tokenization_spaces = tokenizerConfig.optBoolean("clean_up_tokenization_spaces", true);
        this.do_lowercase_and_remove_accent = tokenizerConfig.optBoolean("do_lowercase_and_remove_accent", false);
        this.padding_side = "right";
        this.legacy = false;
    }

    private String getToken(String... keys) throws JSONException {
        for (String key : keys) {
            Object item = this._tokenizer_config.opt(key);
            if (item == null) continue;
            if (item instanceof JSONObject) {
                JSONObject itemObj = (JSONObject) item;
                String __type = itemObj.optString("__type");
                if (__type.equals("AddedToken")) {
                    return itemObj.optString("content");
                } else {
                    throw new IllegalArgumentException("Unknown token: " + item);
                }
            } else {
                return item.toString();
            }
        }
        return null;
    }

    private List<String> _encode_text(String text) throws JSONException {
        if (text == null) return null;

        List<String> sections;
        if (this.added_tokens_regex != null) {
            Matcher matcher = this.added_tokens_regex.matcher(text);
            List<String> parts = new ArrayList<>();
            int lastEnd = 0;
            while (matcher.find()) {
                if (lastEnd != matcher.start()) {
                    parts.add(text.substring(lastEnd, matcher.start()));
                }
                parts.add(matcher.group());
                lastEnd = matcher.end();
            }
            if (lastEnd != text.length()) {
                parts.add(text.substring(lastEnd));
            }
            sections = parts.stream().filter(s -> !s.isEmpty()).collect(Collectors.toList());
        } else {
            sections = Collections.singletonList(text);
        }

        List<String> tokens = new ArrayList<>();
        int section_index = 0;
        for (String x : sections) {
            boolean foundAddedToken = false;
            for (AddedToken t : this.added_tokens) {
                if (t.content.equals(x)) {
                    tokens.add(x);
                    foundAddedToken = true;
                    break;
                }
            }
            if (!foundAddedToken) {
                if (this.remove_space) {
                    x = x.trim().replaceAll("\\s+", " ");
                }
                if (this.do_lowercase_and_remove_accent) {
                    x = lowercase_and_remove_accent(x);
                }
                if (this.normalizer != null) {
                    x = this.normalizer.normalize(x);
                }
                if (x.length() == 0) {
                    continue;
                }
                List<String> sectionTokens;
                if (this.pre_tokenizer != null) {
                    sectionTokens = this.pre_tokenizer.pre_tokenize(x, section_index);
                } else {
                    sectionTokens = Collections.singletonList(x);
                }
                List<String> modelTokens = this.model._call(sectionTokens);
                tokens.addAll(modelTokens);
            }
            section_index++;
        }
        return tokens;
    }

    private Map<String, Object> _encode_plus(String text, boolean add_special_tokens) throws JSONException {
        List<String> tokens = this._encode_text(text);

        Map<String, Object> combinedTokens;
        if (this.post_processor != null) {
            combinedTokens = this.post_processor.post_process(tokens, null, add_special_tokens);
        } else {
            combinedTokens = new HashMap<>();
            combinedTokens.put("tokens", tokens);
        }

        @SuppressWarnings("unchecked")
        List<String> processedTokens = (List<String>) combinedTokens.get("tokens");
        List<Integer> input_ids = this.model.convert_tokens_to_ids(processedTokens);

        Map<String, Object> result = new HashMap<>();
        result.put("input_ids", input_ids);
        result.put("attention_mask", Collections.nCopies(input_ids.size(), 1));
        return result;
    }

    public List<Integer> encode(String text) throws JSONException {
        Map<String, Object> encoded = this._encode_plus(text, true);
        @SuppressWarnings("unchecked")
        List<Integer> input_ids = (List<Integer>) encoded.get("input_ids");
        return input_ids;
    }

    // Helper functions

    public static String remove_accents(String text) {
        String normalized = Normalizer.normalize(text, Normalizer.Form.NFD);
        return normalized.replaceAll("\\p{InCombiningDiacriticalMarks}+", "");
    }

    public static String lowercase_and_remove_accent(String text) {
        return remove_accents(text.toLowerCase());
    }

    public static String escapeRegExp(String string) {
        return string.replaceAll("[.*+?^${}()|\\[\\]\\\\]", "\\\\$0");
    }

    public static List<String> whitespace_split(String text) {
        String[] tokens = text.trim().split("\\s+");
        return Arrays.asList(tokens);
    }
}

2. NormalizerClass.java

package com.transformers.tokenizer; import org.json.JSONException; import org.json.JSONObject; public abstract class NormalizerClass { protected JSONObject config; public NormalizerClass(JSONObject config) { this.config = config; } public static NormalizerClass fromConfig(JSONObject config) throws JSONException { if (config == null) return null; String type = config.optString("type"); switch (type) { case "Precompiled": return new Precompiled(config); default: throw new IllegalArgumentException("Unknown Normalizer type: " + type); } } public abstract String normalize(String text); }

3. Precompiled.java

package com.transformers.tokenizer;

import org.json.JSONObject;

import java.text.Normalizer;

public class Precompiled extends NormalizerClass {

    public Precompiled(JSONObject config) {
        super(config);
    }

    @Override
    public String normalize(String text) {
        // Remove control characters
        text = text.replaceAll("[\\u0001-\\u0008\\u000B\\u000E-\\u001F\\u007F\\u008F\\u009F]", "");
        // Replace certain characters with a space
        text = text.replaceAll("[\\u0009\\u000A\\u000C\\u000D\\u1680\\u200B\\u200C\\u200E\\u200F\\u2028\\u2029\\u2581\\uFEFF\\uFFFD]", " ");

        if (text.contains("\uFF5E")) {
            String[] parts = text.split("\uFF5E");
            for (int i = 0; i < parts.length; i++) {
                parts[i] = Normalizer.normalize(parts[i], Normalizer.Form.NFKC);
            }
            text = String.join("\uFF5E", parts);
        } else {
            text = Normalizer.normalize(text, Normalizer.Form.NFKC);
        }
        return text;
    }
}

4. PreTokenizer.java

package com.transformers.tokenizer; import org.json.JSONException; import org.json.JSONObject; import java.util.List; public abstract class PreTokenizer { public static PreTokenizer fromConfig(JSONObject config) throws JSONException { if (config == null) return null; String type = config.optString("type"); switch (type) { case "Sequence": return new PreTokenizerSequence(config); case "WhitespaceSplit": return new WhitespaceSplit(config); case "Metaspace": return new MetaspacePreTokenizer(config); default: throw new IllegalArgumentException("Unknown PreTokenizer type: " + type); } } public abstract List<String> pre_tokenize_text(String text, int section_index); public List<String> pre_tokenize(String text, int section_index) { return pre_tokenize_text(text, section_index); } }

5. PreTokenizerSequence.java

package com.transformers.tokenizer; import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; import java.util.ArrayList; import java.util.List; public class PreTokenizerSequence extends PreTokenizer { private List<PreTokenizer> tokenizers; public PreTokenizerSequence(JSONObject config) throws JSONException { JSONArray pretokenizersArray = config.getJSONArray("pretokenizers"); this.tokenizers = new ArrayList<>(); for (int i = 0; i < pretokenizersArray.length(); i++) { JSONObject tokenizerConfig = pretokenizersArray.getJSONObject(i); this.tokenizers.add(PreTokenizer.fromConfig(tokenizerConfig)); } } @Override public List<String> pre_tokenize_text(String text, int section_index) { List<String> preTokenizedText = new ArrayList<>(); preTokenizedText.add(text); for (PreTokenizer tokenizer : tokenizers) { List<String> newPreTokenizedText = new ArrayList<>(); for (String t : preTokenizedText) { newPreTokenizedText.addAll(tokenizer.pre_tokenize(t, section_index)); } preTokenizedText = new ArrayList<>(newPreTokenizedText); } return preTokenizedText; } }

6. WhitespaceSplit.java

package com.transformers.tokenizer; import org.json.JSONObject; import java.util.List; public class WhitespaceSplit extends PreTokenizer { public WhitespaceSplit(JSONObject config) { } @Override public List<String> pre_tokenize_text(String text, int section_index) { return PreTrainedTokenizer.whitespace_split(text); } }

7. MetaspacePreTokenizer.java

package com.transformers.tokenizer; import org.json.JSONException; import org.json.JSONObject; import java.util.Collections; import java.util.List; public class MetaspacePreTokenizer extends PreTokenizer { private boolean addPrefixSpace; private String replacement; private String strRep; private String prepend_scheme; public MetaspacePreTokenizer(JSONObject config) throws JSONException { this.addPrefixSpace = config.getBoolean("add_prefix_space"); this.replacement = config.getString("replacement"); this.strRep = config.optString("str_rep", this.replacement); this.prepend_scheme = config.optString("prepend_scheme", "always"); } @Override public List<String> pre_tokenize_text(String text, int section_index) { String normalized = text.replace(" ", this.strRep); if (this.addPrefixSpace && !normalized.startsWith(this.replacement) && (this.prepend_scheme.equals("always") || (this.prepend_scheme.equals("first") && section_index == 0))) { normalized = this.strRep + normalized; } return Collections.singletonList(normalized); } }

8. TokenizerModel.java

package com.transformers.tokenizer; import org.json.JSONException; import org.json.JSONObject; import java.util.*; public abstract class TokenizerModel { protected JSONObject config; protected List<String> vocab; protected Map<String, Integer> tokens_to_ids; protected Integer unk_token_id; protected String unk_token; protected String end_of_word_suffix; protected boolean fuse_unk; public TokenizerModel(JSONObject config) { this.config = config; this.vocab = new ArrayList<>(); this.tokens_to_ids = new HashMap<>(); this.unk_token_id = null; this.unk_token = null; this.end_of_word_suffix = null; this.fuse_unk = config.optBoolean("fuse_unk", false); } public static TokenizerModel fromConfig(JSONObject config, JSONObject tokenizerConfig) throws JSONException { String type = config.getString("type"); switch (type) { case "Unigram": return new Unigram(config, tokenizerConfig); default: throw new IllegalArgumentException("Unknown TokenizerModel type: " + type); } } public List<String> _call(List<String> tokens) { List<String> ids = this.encode(tokens); if (this.fuse_unk) { ids = fuse(ids, this.unk_token_id, this.tokens_to_ids); } return ids; } public abstract List<String> encode(List<String> tokens); public List<Integer> convert_tokens_to_ids(List<String> tokens) { List<Integer> ids = new ArrayList<>(); for (String t : tokens) { Integer id = this.tokens_to_ids.getOrDefault(t, this.unk_token_id); ids.add(id); } return ids; } public List<String> convert_ids_to_tokens(List<Integer> ids) { List<String> tokens = new ArrayList<>(); for (Integer i : ids) { String token = (i < this.vocab.size() && i >= 0) ? this.vocab.get(i) : this.unk_token; tokens.add(token); } return tokens; } public static List<String> fuse(List<String> arr, Integer value, Map<String, Integer> mapping) { List<String> fused = new ArrayList<>(); int i = 0; while (i < arr.size()) { fused.add(arr.get(i)); Integer mappedValue = mapping.getOrDefault(arr.get(i), value); if (!Objects.equals(mappedValue, value)) { i++; continue; } while (i < arr.size() && Objects.equals(mapping.getOrDefault(arr.get(i), value), value)) { i++; } } return fused; } }

9. Unigram.java

package com.transformers.tokenizer; import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; import java.util.ArrayList; import java.util.List; import java.util.Map; public class Unigram extends TokenizerModel { private double[] scores; private String bosToken; private Integer bosTokenId; private String eosToken; private Integer eosTokenId; private String unkToken; private double minScore; private double unkScore; private CharTrie trie; public Unigram(JSONObject config, JSONObject tokenizerConfig) throws JSONException { super(config); JSONArray vocabArray = config.getJSONArray("vocab"); int vocabSize = vocabArray.length(); this.vocab = new ArrayList<>(vocabSize); this.scores = new double[vocabSize]; for (int i = 0; i < vocabSize; ++i) { JSONArray piece = vocabArray.getJSONArray(i); this.vocab.add(piece.getString(0)); this.scores[i] = piece.getDouble(1); } this.unk_token_id = config.getInt("unk_id"); this.unk_token = this.vocab.get(this.unk_token_id); for (int i = 0; i < this.vocab.size(); i++) { this.tokens_to_ids.put(this.vocab.get(i), i); } this.bosToken = " "; this.bosTokenId = this.tokens_to_ids.get(this.bosToken); this.eosToken = tokenizerConfig.optString("eos_token"); this.eosTokenId = this.tokens_to_ids.get(this.eosToken); this.unkToken = this.vocab.get(this.unk_token_id); double[] minResult = min(this.scores); this.minScore = minResult[0]; this.unkScore = this.minScore - 10.0; this.scores[this.unk_token_id] = this.unkScore; this.trie = new CharTrie(); this.trie.extend(this.vocab); this.fuse_unk = true; } private void populateNodes(TokenLattice lattice) { String sentence = lattice.sentence; int len = sentence.length(); int beginPos = 0; while (beginPos < len) { int mblen = 1; boolean hasSingleNode = false; for (String token : this.trie.commonPrefixSearch(sentence.substring(beginPos))) { int tokenId = this.tokens_to_ids.get(token); double tokenScore = this.scores[tokenId]; int n = token.length(); lattice.insert(beginPos, n, tokenScore, tokenId); if (!hasSingleNode && n == mblen) { hasSingleNode = true; } } if (!hasSingleNode) { lattice.insert(beginPos, mblen, this.unkScore, this.unk_token_id); } beginPos += mblen; } } private List<String> tokenize(String normalized) { TokenLattice lattice = new TokenLattice(normalized, this.bosTokenId, this.eosTokenId); this.populateNodes(lattice); return lattice.tokens(); } @Override public List<String> encode(List<String> tokens) { List<String> toReturn = new ArrayList<>(); for (String token : tokens) { List<String> tokenized = this.tokenize(token); toReturn.addAll(tokenized); } return toReturn; } public static double[] min(double[] arr) { if (arr.length == 0) throw new IllegalArgumentException("Array must not be empty"); double min = arr[0]; int indexOfMin = 0; for (int i = 1; i < arr.length; ++i) { if (arr[i] < min) { min = arr[i]; indexOfMin = i; } } return new double[]{min, indexOfMin}; } }

10. CharTrie.java

package com.transformers.tokenizer; import java.util.ArrayList; import java.util.List; public class CharTrie { private CharTrieNode root; public CharTrie() { this.root = CharTrieNode.defaultNode(); } public void extend(List<String> texts) { if (texts.size() == 0) return; for (String text : texts) { CharTrieNode node = this.root; for (char ch : text.toCharArray()) { CharTrieNode child = node.children.get(ch); if (child == null) { child = CharTrieNode.defaultNode(); node.children.put(ch, child); } node = child; } node.isLeaf = true; } } public List<String> commonPrefixSearch(String text) { List<String> results = new ArrayList<>(); CharTrieNode node = this.root; StringBuilder prefix = new StringBuilder(); for (int i = 0; i < text.length() && node != null; ++i) { char ch = text.charAt(i); prefix.append(ch); node = node.children.get(ch); if (node != null && node.isLeaf) { results.add(prefix.toString()); } } return results; } }

11. CharTrieNode.java

package com.transformers.tokenizer; import java.util.HashMap; import java.util.Map; public class CharTrieNode { public boolean isLeaf; public Map<Character, CharTrieNode> children; public CharTrieNode(boolean isLeaf) { this.isLeaf = isLeaf; this.children = new HashMap<>(); } public static CharTrieNode defaultNode() { return new CharTrieNode(false); } }

12. PostProcessor.java

package com.transformers.tokenizer;

import org.json.JSONException;
import org.json.JSONObject;

import java.util.List;
import java.util.Map;

public abstract class PostProcessor {

    protected JSONObject config;

    public PostProcessor(JSONObject config) {
        this.config = config;
    }

    public static PostProcessor fromConfig(JSONObject config) throws JSONException {
        if (config == null) return null;
        String type = config.optString("type");
        switch (type) {
            case "TemplateProcessing":
                return new TemplateProcessing(config);
            default:
                throw new IllegalArgumentException("Unknown PostProcessor type: " + type);
        }
    }

    public abstract Map<String, Object> post_process(
            List<String> tokens, List<String> tokens_pair, boolean add_special_tokens) throws JSONException;
}

13. TemplateProcessing.java

package com.transformers.tokenizer; import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; import java.util.*; public class TemplateProcessing extends PostProcessor { private List<Map<String, Object>> single; private List<Map<String, Object>> pair; public TemplateProcessing(JSONObject config) throws JSONException { super(config); this.single = parseTemplate(config.getJSONArray("single")); this.pair = parseTemplate(config.optJSONArray("pair")); } private List<Map<String, Object>> parseTemplate(JSONArray templateArray) throws JSONException { List<Map<String, Object>> template = new ArrayList<>(); for (int i = 0; i < templateArray.length(); i++) { JSONObject item = templateArray.getJSONObject(i); Map<String, Object> map = new HashMap<>(); if (item.has("SpecialToken")) { map.put("SpecialToken", item.getJSONObject("SpecialToken")); } else if (item.has("Sequence")) { map.put("Sequence", item.getJSONObject("Sequence")); } template.add(map); } return template; } @Override public Map<String, Object> post_process( List<String> tokens, List<String> tokens_pair, boolean add_special_tokens) throws JSONException { List<Map<String, Object>> type = (tokens_pair == null) ? this.single : this.pair; List<String> processedTokens = new ArrayList<>(); List<Integer> types = new ArrayList<>(); for (Map<String, Object> item : type) { if (item.containsKey("SpecialToken")) { if (add_special_tokens) { JSONObject specialToken = (JSONObject) item.get("SpecialToken"); processedTokens.add(specialToken.getString("id")); types.add(specialToken.getInt("type_id")); } } else if (item.containsKey("Sequence")) { JSONObject sequence = (JSONObject) item.get("Sequence"); String id = sequence.getString("id"); int typeId = sequence.getInt("type_id"); if (id.equals("A")) { processedTokens.addAll(tokens); for (int i = 0; i < tokens.size(); i++) { types.add(typeId); } } else if (id.equals("B") && tokens_pair != null) { processedTokens.addAll(tokens_pair); for (int i = 0; i < tokens_pair.size(); i++) { types.add(typeId); } } } } Map<String, Object> result = new HashMap<>(); result.put("tokens", processedTokens); result.put("token_type_ids", types); return result; } }

14. AddedToken.java

package com.transformers.tokenizer; import org.json.JSONException; import org.json.JSONObject; public class AddedToken { public String content; public int id; public boolean single_word; public boolean lstrip; public boolean rstrip; public boolean special; public Boolean normalized; public AddedToken(JSONObject config) throws JSONException { this.content = config.getString("content"); this.id = config.getInt("id"); this.single_word = config.optBoolean("single_word", false); this.lstrip = config.optBoolean("lstrip", false); this.rstrip = config.optBoolean("rstrip", false); this.special = config.optBoolean("special", false); this.normalized = config.has("normalized") ? config.getBoolean("normalized") : null; } }

15. TokenLattice.java

package com.transformers.tokenizer;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class TokenLattice {

    public String sentence;
    public int len;
    public int bosTokenId;
    public int eosTokenId;
    public List<TokenLatticeNode> nodes;
    public List<List<TokenLatticeNode>> beginNodes;
    public List<List<TokenLatticeNode>> endNodes;

    public TokenLattice(String sentence, Integer bosTokenId, Integer eosTokenId) {
        this.sentence = sentence;
        this.len = sentence.length();
        // Guard against missing ids (e.g. when the BOS token is not actually in the vocabulary)
        this.bosTokenId = (bosTokenId != null) ? bosTokenId : -1;
        this.eosTokenId = (eosTokenId != null) ? eosTokenId : -1;
        this.nodes = new ArrayList<>();
        this.beginNodes = new ArrayList<>(len + 1);
        this.endNodes = new ArrayList<>(len + 1);
        for (int i = 0; i <= len; i++) {
            this.beginNodes.add(new ArrayList<>());
            this.endNodes.add(new ArrayList<>());
        }

        TokenLatticeNode bos = new TokenLatticeNode(this.bosTokenId, 0, 0, 0, 0.0);
        TokenLatticeNode eos = new TokenLatticeNode(this.eosTokenId, 1, len, 0, 0.0);
        this.nodes.add(bos.clone());
        this.nodes.add(eos.clone());
        this.beginNodes.get(len).add(eos);
        this.endNodes.get(0).add(bos);
    }

    public void insert(int pos, int length, double score, int tokenId) {
        int nodeId = this.nodes.size();
        TokenLatticeNode node = new TokenLatticeNode(tokenId, nodeId, pos, length, score);
        this.beginNodes.get(pos).add(node);
        this.endNodes.get(pos + length).add(node);
        this.nodes.add(node);
    }

    public List<TokenLatticeNode> viterbi() {
        int len = this.len;
        int pos = 0;
        while (pos <= len) {
            if (this.beginNodes.get(pos).size() == 0) {
                return new ArrayList<>();
            }
            for (TokenLatticeNode rnode : this.beginNodes.get(pos)) {
                rnode.prev = null;
                double bestScore = 0.0;
                TokenLatticeNode bestNode = null;
                for (TokenLatticeNode lnode : this.endNodes.get(pos)) {
                    double score = lnode.backtraceScore + rnode.score;
                    if (bestNode == null || score > bestScore) {
                        bestNode = lnode.clone();
                        bestScore = score;
                    }
                }
                if (bestNode != null) {
                    rnode.prev = bestNode;
                    rnode.backtraceScore = bestScore;
                } else {
                    return new ArrayList<>();
                }
            }
            ++pos;
        }

        List<TokenLatticeNode> results = new ArrayList<>();
        TokenLatticeNode root = this.beginNodes.get(len).get(0);
        TokenLatticeNode prev = root.prev;
        if (prev == null) {
            return new ArrayList<>();
        }

        TokenLatticeNode node = prev.clone();
        while (node.prev != null) {
            results.add(node.clone());
            node = node.prev.clone();
        }
        Collections.reverse(results);
        return results;
    }

    public String piece(TokenLatticeNode node) {
        return this.sentence.substring(node.pos, node.pos + node.length);
    }

    public List<String> tokens() {
        List<TokenLatticeNode> nodes = this.viterbi();
        List<String> tokens = new ArrayList<>();
        for (TokenLatticeNode node : nodes) {
            tokens.add(this.piece(node));
        }
        return tokens;
    }

    public List<Integer> tokenIds() {
        List<TokenLatticeNode> nodes = this.viterbi();
        List<Integer> tokenIds = new ArrayList<>();
        for (TokenLatticeNode node : nodes) {
            tokenIds.add(node.tokenId);
        }
        return tokenIds;
    }
}

16. TokenLatticeNode.java

package com.transformers.tokenizer; public class TokenLatticeNode { public int tokenId; public int nodeId; public int pos; public int length; public double score; public TokenLatticeNode prev; public double backtraceScore; public TokenLatticeNode(int tokenId, int nodeId, int pos, int length, double score) { this.tokenId = tokenId; this.nodeId = nodeId; this.pos = pos; this.length = length; this.score = score; this.prev = null; this.backtraceScore = 0.0; } public TokenLatticeNode clone() { TokenLatticeNode n = new TokenLatticeNode(this.tokenId, this.nodeId, this.pos, this.length, this.score); n.prev = this.prev; n.backtraceScore = this.backtraceScore; return n; } }

Note:

  • Make sure to handle exceptions appropriately, especially JSONException and any potential NullPointerException.
  • The code above assumes that you have access to Java 8 or higher for certain features like streams and lambdas.
  • For Android development, ensure that your build.gradle file is set up to support Java 8 features.
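
Loading Tokenizer Files from Android Assets:

If the tokenizer files are bundled in your app's assets folder rather than stored at an absolute filesystem path, the FileReader-based readJSONFromFile in your module will not find them. Below is a minimal sketch of an assets-based variant you could drop into TokenizerModule; the helper name and asset filenames are illustrative, and in addition to the imports already present it needs java.io.InputStreamReader and java.nio.charset.StandardCharsets.

// Hypothetical helper for TokenizerModule; the asset name passed in (e.g. "tokenizer.json")
// must match whatever you actually bundle under src/main/assets.
private JSONObject readJSONFromAssets(String assetName) throws IOException, JSONException {
    StringBuilder content = new StringBuilder();
    try (BufferedReader br = new BufferedReader(
            new InputStreamReader(reactContext.getAssets().open(assetName), StandardCharsets.UTF_8))) {
        String line;
        while ((line = br.readLine()) != null) {
            content.append(line);
        }
    }
    return new JSONObject(content.toString());
}

You could then have initializeTokenizer fall back to this helper when the supplied path does not exist on disk.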

Usage Example in Your Module:

You can now use this tokenizer in your TokenizerModule as you intended:

// Inside your TokenizerModule class

@ReactMethod
public void encode(String text, Promise promise) {
    executorService.execute(() -> {
        try {
            if (tokenizer == null) {
                throw new IllegalStateException("Tokenizer not initialized. Call initializeTokenizer first.");
            }
            List<Integer> encodedIds = tokenizer.encode(text);

            WritableArray result = Arguments.createArray();
            for (Integer id : encodedIds) {
                result.pushInt(id);
            }

            new Handler(Looper.getMainLooper()).post(() ->
                promise.resolve(result)
            );
        } catch (IllegalStateException e) {
            new Handler(Looper.getMainLooper()).post(() ->
                promise.reject("TOKENIZER_NOT_INITIALIZED", e.getMessage())
            );
        } catch (Exception e) {
            new Handler(Looper.getMainLooper()).post(() ->
                promise.reject("ENCODING_ERROR", e.getMessage())
            );
        }
    });
}

Final Notes:

  • Be sure to test each component thoroughly.
  • Depending on the complexity of your tokenizer configurations, you might need to adjust some parts of the code, especially where JSON parsing is involved.
  • Ensure that all dependencies are properly imported and that your project is configured to handle them.
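
Quick Sanity Check Outside React Native:

As a quick sanity check before wiring everything through React Native, you could exercise the tokenizer from plain Java, for example in an Android instrumented test or a scratch main method (org.json ships with Android, so a desktop JVM run would additionally need the org.json artifact on the classpath). A rough sketch, with placeholder file paths and a sample sentence:

import com.transformers.tokenizer.PreTrainedTokenizer;
import org.json.JSONObject;

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.List;

public class TokenizerSmokeTest {

    // Reads a whole file into a string; the paths passed in below are placeholders.
    private static String readFile(String path) throws Exception {
        StringBuilder sb = new StringBuilder();
        try (BufferedReader br = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = br.readLine()) != null) {
                sb.append(line);
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) throws Exception {
        JSONObject tokenizerJSON = new JSONObject(readFile("/path/to/tokenizer.json"));
        JSONObject tokenizerConfig = new JSONObject(readFile("/path/to/tokenizer_config.json"));

        PreTrainedTokenizer tokenizer = new PreTrainedTokenizer(tokenizerJSON, tokenizerConfig);

        // Prints the vocabulary ids produced for the sample sentence.
        List<Integer> ids = tokenizer.encode("I like cats");
        System.out.println(ids);
    }
}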

Feel free to ask if you need any further assistance or clarifications!
