Skip to content

Commit

Permalink
Merge pull request #122 from seart-group/feature/lookahead
Browse files Browse the repository at this point in the history
  • Loading branch information
dabico authored Feb 5, 2024
2 parents 3f69f3f + 910ae23 commit d1ad650
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 4 deletions.
18 changes: 18 additions & 0 deletions lib/ch_usi_si_seart_treesitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ jfieldID _treeCursorIdField;
jfieldID _treeCursorTreeField;
jmethodID _treeCursorConstructor;

// JNI handles for ch.usi.si.seart.treesitter.LookaheadIterator, resolved in JNI_OnLoad:
// the class itself, its "hasNext" (boolean) and "language" fields,
// and its (long, boolean, Language) constructor.
jclass _lookaheadIteratorClass;
jfieldID _lookaheadIteratorHasNextField;
jfieldID _lookaheadIteratorLanguageField;
jmethodID _lookaheadIteratorConstructor;

jclass _noSuchElementExceptionClass;
jclass _nullPointerExceptionClass;
jclass _illegalArgumentExceptionClass;
jclass _illegalStateExceptionClass;
Expand Down Expand Up @@ -252,6 +258,12 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
_loadField(_treeCursorTreeField, _treeCursorClass, "tree", "Lch/usi/si/seart/treesitter/Tree;")
_loadConstructor(_treeCursorConstructor, _treeCursorClass, "(JIIJLch/usi/si/seart/treesitter/Tree;)V")

_loadClass(_lookaheadIteratorClass, "ch/usi/si/seart/treesitter/LookaheadIterator")
_loadField(_lookaheadIteratorHasNextField, _lookaheadIteratorClass, "hasNext", "Z")
_loadField(_lookaheadIteratorLanguageField, _lookaheadIteratorClass, "language", "Lch/usi/si/seart/treesitter/Language;")
_loadConstructor(_lookaheadIteratorConstructor, _lookaheadIteratorClass, "(JZLch/usi/si/seart/treesitter/Language;)V")

_loadClass(_noSuchElementExceptionClass, "java/util/NoSuchElementException")
_loadClass(_nullPointerExceptionClass, "java/lang/NullPointerException")
_loadClass(_illegalArgumentExceptionClass, "java/lang/IllegalArgumentException")
_loadClass(_illegalStateExceptionClass, "java/lang/IllegalStateException")
Expand Down Expand Up @@ -315,6 +327,8 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {
_unload(_queryCursorClass)
_unload(_symbolClass)
_unload(_treeCursorClass)
_unload(_lookaheadIteratorClass)
_unload(_noSuchElementExceptionClass)
_unload(_nullPointerExceptionClass)
_unload(_illegalArgumentExceptionClass)
_unload(_illegalStateExceptionClass)
Expand All @@ -337,6 +351,10 @@ ComparisonResult intcmp(uint32_t x, uint32_t y) {
return (x < y) ? LT : ((x == y) ? EQ : GT);
}

// Raises a java.util.NoSuchElementException carrying `message` in the calling
// Java thread and hands back whatever status `_throwNew` produced.
// NOTE(review): `env` is not passed explicitly; presumably the `_throwNew`
// macro picks it up from the enclosing scope - confirm against its definition.
jint __throwNSE(JNIEnv* env, const char* message) {
  const jint status = _throwNew(_noSuchElementExceptionClass, message);
  return status;
}

// Raises a java.lang.NullPointerException carrying `message` in the calling
// Java thread and hands back whatever status `_throwNew` produced.
// NOTE(review): `env` is not passed explicitly; presumably the `_throwNew`
// macro picks it up from the enclosing scope - confirm against its definition.
jint __throwNPE(JNIEnv* env, const char* message) {
  const jint status = _throwNew(_nullPointerExceptionClass, message);
  return status;
}
Expand Down
8 changes: 8 additions & 0 deletions lib/ch_usi_si_seart_treesitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ extern jclass _mapClass;
extern jclass _mapEntryClass;
extern jmethodID _mapEntryStaticMethod;

extern jclass _noSuchElementExceptionClass;
extern jclass _nullPointerExceptionClass;
extern jclass _illegalArgumentExceptionClass;
extern jclass _illegalStateExceptionClass;
Expand Down Expand Up @@ -116,6 +117,11 @@ extern jfieldID _treeCursorIdField;
extern jfieldID _treeCursorTreeField;
extern jmethodID _treeCursorConstructor;

// JNI handles for the Java LookaheadIterator class: the class reference,
// its two instance fields, and its constructor.
extern jclass _lookaheadIteratorClass;
extern jfieldID _lookaheadIteratorHasNextField;
extern jfieldID _lookaheadIteratorLanguageField;
extern jmethodID _lookaheadIteratorConstructor;

extern jclass _treeSitterExceptionClass;

extern jclass _byteOffsetOutOfBoundsExceptionClass;
Expand Down Expand Up @@ -210,6 +216,8 @@ typedef enum {

ComparisonResult intcmp(uint32_t x, uint32_t y);

jint __throwNSE(JNIEnv* env, const char* message);

jint __throwNPE(JNIEnv* env, const char* message);

jint __throwIAE(JNIEnv* env, const char* message);
Expand Down
31 changes: 27 additions & 4 deletions lib/ch_usi_si_seart_treesitter_Language.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,12 +510,12 @@ JNIEXPORT jlong JNICALL Java_ch_usi_si_seart_treesitter_Language_zig(

// Native backing for Language.version(long): reinterprets `id` as a
// TSLanguage pointer and returns the result of ts_language_version as a jint.
JNIEXPORT jint JNICALL Java_ch_usi_si_seart_treesitter_Language_version(
  JNIEnv* env, jclass self, jlong id) {
  return (jint)ts_language_version((const TSLanguage*)id);
}

// Native backing for Language.symbols(long): reinterprets `id` as a
// TSLanguage pointer and returns the result of ts_language_symbol_count as a jint.
JNIEXPORT jint JNICALL Java_ch_usi_si_seart_treesitter_Language_symbols(
  JNIEnv* env, jclass self, jlong id) {
  return (jint)ts_language_symbol_count((const TSLanguage*)id);
}

JNIEXPORT jobject JNICALL Java_ch_usi_si_seart_treesitter_Language_symbol(
Expand All @@ -535,10 +535,33 @@ JNIEXPORT jobject JNICALL Java_ch_usi_si_seart_treesitter_Language_symbol(

// Native backing for Language.fields(long): reinterprets `id` as a
// TSLanguage pointer and returns the result of ts_language_field_count as a jint.
JNIEXPORT jint JNICALL Java_ch_usi_si_seart_treesitter_Language_fields(
  JNIEnv* env, jclass self, jlong id) {
  return (jint)ts_language_field_count((const TSLanguage*)id);
}

// Native backing for Language.states(long): reinterprets `id` as a
// TSLanguage pointer and returns the result of ts_language_state_count as a jint.
JNIEXPORT jint JNICALL Java_ch_usi_si_seart_treesitter_Language_states(
  JNIEnv* env, jclass self, jlong id) {
  return (jint)ts_language_state_count((const TSLanguage*)id);
}

/*
 * Native backing for Language#iterator(int).
 *
 * Reads the TSLanguage pointer from the receiver's long "id" field, validates
 * `state` against the language's state count, allocates a TSLookaheadIterator
 * and wraps it in a Java LookaheadIterator. The iterator is advanced once up
 * front so the Java object starts out knowing whether a first symbol exists.
 *
 * Throws (Java side):
 *   IllegalArgumentException - `state` is negative or not below the state count
 *   IllegalStateException    - tree-sitter could not allocate the iterator
 *
 * Returns the new LookaheadIterator, or NULL with a Java exception pending.
 */
JNIEXPORT jobject JNICALL Java_ch_usi_si_seart_treesitter_Language_iterator(
  JNIEnv* env, jobject thisObject, jint state) {
  jclass languageClass = env->GetObjectClass(thisObject);
  jfieldID languageIdField = env->GetFieldID(languageClass, "id", "J");
  const TSLanguage* language = (const TSLanguage*)env->GetLongField(thisObject, languageIdField);
  // The sign is checked first, so the unsigned cast below cannot wrap.
  if (state < 0 || (uint32_t)state >= ts_language_state_count(language)) {
    __throwIAE(env, "Invalid parse state!");
    return NULL;
  }
  TSLookaheadIterator* iterator = ts_lookahead_iterator_new(language, (TSStateId)state);
  if (iterator == NULL) {
    __throwISE(env, "Unable to create lookahead iterator!");
    return NULL;
  }
  // Prime the iterator: the result seeds the Java-side `hasNext` flag.
  jboolean hasNext = ts_lookahead_iterator_next(iterator) ? JNI_TRUE : JNI_FALSE;
  jobject result = env->NewObject(
    _lookaheadIteratorClass,
    _lookaheadIteratorConstructor,
    (jlong)iterator,
    hasNext,
    thisObject
  );
  if (result == NULL) {
    // Construction failed (an exception is pending), so no Java object owns
    // the native iterator; release it here instead of leaking it.
    ts_lookahead_iterator_delete(iterator);
  }
  return result;
}
8 changes: 8 additions & 0 deletions lib/ch_usi_si_seart_treesitter_Language.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 37 additions & 0 deletions lib/ch_usi_si_seart_treesitter_LookaheadIterator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "ch_usi_si_seart_treesitter.h"
#include "ch_usi_si_seart_treesitter_LookaheadIterator.h"
#include <jni.h>
#include <tree_sitter/api.h>

// Native backing for LookaheadIterator#delete(): frees the TSLookaheadIterator
// referenced by this object's stored pointer, then clears that pointer via
// __clearPointer.
JNIEXPORT void JNICALL Java_ch_usi_si_seart_treesitter_LookaheadIterator_delete(
  JNIEnv* env, jobject thisObject) {
  ts_lookahead_iterator_delete((TSLookaheadIterator*)__getPointer(env, thisObject));
  __clearPointer(env, thisObject);
}

/*
 * Native backing for LookaheadIterator#next().
 *
 * Throws java.util.NoSuchElementException (with a NULL message) when the
 * object's `hasNext` flag is false. Otherwise captures the symbol the native
 * iterator currently points at, advances the native iterator, stores the
 * outcome of that advance back into `hasNext`, and returns a new Symbol built
 * from the captured id, type and name.
 */
JNIEXPORT jobject JNICALL Java_ch_usi_si_seart_treesitter_LookaheadIterator_next(
  JNIEnv* env, jobject thisObject) {
  if (env->GetBooleanField(thisObject, _lookaheadIteratorHasNextField) == JNI_FALSE) {
    __throwNSE(env, NULL);
    return NULL;
  }
  TSLookaheadIterator* cursor = (TSLookaheadIterator*)__getPointer(env, thisObject);
  // Everything about the current symbol is read BEFORE advancing the cursor.
  const TSSymbol current = ts_lookahead_iterator_current_symbol(cursor);
  const TSLanguage* grammar = ts_lookahead_iterator_language(cursor);
  const char* label = ts_language_symbol_name(grammar, current);
  const TSSymbolType kind = ts_language_symbol_type(grammar, current);
  // Advance now and record whether another symbol is available afterwards.
  const jboolean more = ts_lookahead_iterator_next(cursor) ? JNI_TRUE : JNI_FALSE;
  env->SetBooleanField(thisObject, _lookaheadIteratorHasNextField, more);
  jstring symbolName = env->NewStringUTF(label);
  return env->NewObject(
    _symbolClass,
    _symbolConstructor,
    (jint)current,
    (jint)kind,
    symbolName
  );
}
29 changes: 29 additions & 0 deletions lib/ch_usi_si_seart_treesitter_LookaheadIterator.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions src/main/java/ch/usi/si/seart/treesitter/Language.java
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,20 @@ public static void validate(@NotNull Language language) {
this.extensions = extensions;
}

/**
 * Create a lookahead iterator, beginning from a specific parse state.
 * The native implementation advances the underlying iterator once
 * during creation, so the result of the first {@code hasNext()} call
 * is already known when this method returns.
 *
 * @param state the parse state
 * @return a lookahead iterator
 * @throws IllegalArgumentException if:
 * <ul>
 * <li>{@code state} &lt; 0</li>
 * <li>{@code state} &ge; {@link #totalStates}</li>
 * </ul>
 * @throws IllegalStateException if the native lookahead iterator could not be created
 * @since 1.12.0
 */
public native LookaheadIterator iterator(int state);

@Generated
@SuppressWarnings("unused")
public int getTotalSymbols() {
Expand Down
40 changes: 40 additions & 0 deletions src/main/java/ch/usi/si/seart/treesitter/LookaheadIterator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package ch.usi.si.seart.treesitter;

import lombok.Getter;
import lombok.experimental.FieldDefaults;

import java.util.Iterator;

/**
* Specialized iterator that can be used to generate suggestions and improve syntax error diagnostics.
* To get symbols valid in an {@code ERROR} node, use the lookahead iterator on its first leaf node state.
* For {@code MISSING} nodes, a lookahead iterator created on the previous non-extra leaf node may be appropriate.
*
* @since 1.12.0
* @author Ozren Dabić
*/
@FieldDefaults(level = lombok.AccessLevel.PRIVATE, makeFinal = true)
public class LookaheadIterator extends External implements Iterator<Symbol> {

boolean hasNext;

@Getter
Language language;

LookaheadIterator(long pointer, boolean hasNext, Language language) {
super(pointer);
this.hasNext = hasNext;
this.language = language;
}

@Override
protected native void delete();

@Override
public boolean hasNext() {
return hasNext;
}

@Override
public native Symbol next();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package ch.usi.si.seart.treesitter;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

/**
 * Exercises {@link LookaheadIterator} creation and traversal for the Java grammar.
 */
class LookaheadIteratorTest extends TestBase {

    private static final Language language = Language.JAVA;

    /**
     * Drains an iterator created on state 0 and checks that it produced known symbols,
     * then verifies that a further {@code next()} call is rejected.
     */
    @Test
    void testIterator() {
        try (LookaheadIterator lookahead = language.iterator(0)) {
            List<Symbol> collected = StreamSupport
                    .stream(Spliterators.spliteratorUnknownSize(lookahead, Spliterator.ORDERED), false)
                    .collect(Collectors.toUnmodifiableList());
            Assertions.assertFalse(collected.isEmpty());
            Assertions.assertTrue(language.getSymbols().containsAll(collected));
            Assertions.assertThrows(NoSuchElementException.class, lookahead::next);
        } catch (NoSuchElementException ignored) {
            // Exhaustion must only surface through the assertThrows above.
            Assertions.fail();
        }
    }

    /**
     * Every parse state outside of {@code [0, totalStates)} must be rejected.
     */
    @Test
    @SuppressWarnings("resource")
    void testIteratorThrows() {
        int[] invalid = {Integer.MIN_VALUE, -1, language.getTotalStates(), Integer.MAX_VALUE};
        for (int state : invalid) {
            Assertions.assertThrows(IllegalArgumentException.class, () -> language.iterator(state));
        }
    }
}

0 comments on commit d1ad650

Please sign in to comment.