Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions lsproxy/src/ast_grep/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@ const SYMBOL_CONFIG_PATH: &str = "/usr/src/ast_grep/symbol/config.yml";
const IDENTIFIER_CONFIG_PATH: &str = "/usr/src/ast_grep/identifier/config.yml";
const REFERENCE_CONFIG_PATH: &str = "/usr/src/ast_grep/reference/config.yml";

fn get_symbol_config_path() -> String {
std::env::var("AST_GREP_SYMBOL_CONFIG").unwrap_or_else(|_| SYMBOL_CONFIG_PATH.to_string())
}

fn get_identifier_config_path() -> String {
std::env::var("AST_GREP_IDENTIFIER_CONFIG").unwrap_or_else(|_| IDENTIFIER_CONFIG_PATH.to_string())
}

fn get_reference_config_path() -> String {
std::env::var("AST_GREP_REFERENCE_CONFIG").unwrap_or_else(|_| REFERENCE_CONFIG_PATH.to_string())
}

use super::types::AstGrepMatch;

pub struct AstGrepClient;
Expand All @@ -16,7 +28,7 @@ impl AstGrepClient {
identifier_position: &lsp_types::Position,
) -> Result<AstGrepMatch, Box<dyn std::error::Error>> {
// Get all symbols in the file
let file_symbols = self.scan_file(SYMBOL_CONFIG_PATH, file_name).await?;
let file_symbols = self.scan_file(&get_symbol_config_path(), file_name).await?;

// Find the symbol that matches our identifier position
let symbol_result = file_symbols.into_iter().find(|ast_symbol_match| {
Expand All @@ -43,14 +55,14 @@ impl AstGrepClient {
&self,
file_name: &str,
) -> Result<Vec<AstGrepMatch>, Box<dyn std::error::Error>> {
self.scan_file(SYMBOL_CONFIG_PATH, file_name).await
self.scan_file(&get_symbol_config_path(), file_name).await
}

pub async fn get_file_identifiers(
&self,
file_name: &str,
) -> Result<Vec<AstGrepMatch>, Box<dyn std::error::Error>> {
self.scan_file(IDENTIFIER_CONFIG_PATH, file_name).await
self.scan_file(&get_identifier_config_path(), file_name).await
}

pub async fn get_symbol_and_references(
Expand All @@ -75,7 +87,7 @@ impl AstGrepClient {
full_scan: bool,
) -> Result<Vec<AstGrepMatch>, Box<dyn std::error::Error>> {
// Get all references
let matches = self.scan_file(REFERENCE_CONFIG_PATH, file_name).await?;
let matches = self.scan_file(&get_reference_config_path(), file_name).await?;

// Filter matches to those within the symbol's range
// And if not full_scan, exclude matches with rule_id "non-function"
Expand Down
Loading