Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions modules/data.remote/inst/predict_and_store.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
##loops through all the parcel ids to make predictions for each from year 2024 - 2040 or something
##store the predictions in a new file, then graph a sample.

setwd("/projectnb/dietzelab/ananyak")
library(ggplot2)
library(data.table)

seq_long = fread('seq_long.csv')
season_idx = 1
end_year = 2030

##read tmat file but make it numeric again and fix formatting
tmat_df = fread('full_transition_matrix.csv')
states = colnames(tmat_df)[-1]
tmat_final = as.matrix(tmat_df[, -1, with = FALSE])
rownames(tmat_final) = tmat_df$V1
colnames(tmat_final) = states
storage.mode(tmat_final) = "double"

#check
print(head(rownames(tmat_final)))
print(head(colnames(tmat_final)))

setDT(seq_long)

tmat_year = tmat_final

start_info = seq_long[
season == season_idx,
.SD[which.max(year)],
by = parcel_id
]

#clean classes and set in same order as transition matrix
start_info[, CLASS := trimws(as.character(CLASS))]
states = rownames(tmat_final)

##all predictions
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the relationship between "predictions_with_tmatrix.R" and the code in this file? It doesn't look like you're using those functions here. I'd recommend either using those functions or converting the core of this code into a function, and then have the script here focus on reading specific data, calling the prediction function, and then performing the validation. Indeed, such a script might be better reformatted as a Rmd vignette.

all_preds = start_info[, {

p = setNames(rep(0, length(states)), states)

idx0 = match(CLASS, states)
if (is.na(idx0)) return(NULL)

p[idx0] = 1

years = seq(year + 1, end_year) # FIXED

preds = character(length(years))
probs = numeric(length(years))

for (i in seq_along(years)) {
p = as.numeric(p %*% tmat_year)

idx = sample(seq_along(p), 1, prob = p)

preds[i] = states[idx]
probs[i] = p[idx]
}

.(
season = season_idx,
year = years,
pred_class = preds,
pred_prob = probs
)

}, by = parcel_id]

#store actual classes from 2018-2023
actual_hist = seq_long[
season == season_idx,
.(parcel_id, year, pred_class = NA, actual_class = CLASS)
]

#2024 to end year (right now is 2030)
preds_future = all_preds[, .(
parcel_id,
year,
pred_class,
actual_class = NA
)]

#comboine to plot both for comparisons
plot_data = rbind(actual_hist, preds_future, fill = TRUE)

sample_pids = sample(unique(plot_data$parcel_id), 1)

plot_subset = plot_data[parcel_id %in% sample_pids]

plot_subset[, pred_class := factor(pred_class, levels = states)]
plot_subset[, actual_class := factor(actual_class, levels = states)]

ggplot(plot_subset, aes(x = year)) +

geom_point(
data = plot_subset[!is.na(pred_class)],
aes(y = pred_class, color = pred_class),
size = 3
) +

geom_point(
data = plot_subset[!is.na(actual_class)],
aes(y = actual_class),
shape = 1,
size = 3,
color = "black"
) +

facet_wrap(~parcel_id, ncol = 2) +

scale_x_continuous(
limits = c(2018, end_year),
breaks = seq(2018, end_year, by = 1)
) +

ggtitle(sprintf("Predicted and Actual crop classes for parcel %s", sample_pids)) +
theme_minimal()

93 changes: 93 additions & 0 deletions modules/data.remote/inst/predictions_with_tmatrix.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
library(expm)

predict_k_steps = function(tmat, current_state, k) {
if (!is.matrix(tmat)) tmat = as.matrix(tmat)
if (is.null(rownames(tmat)) || is.null(colnames(tmat)))
stop("tmat must have row/col names.")
if (!(current_state %in% rownames(tmat)))
stop("State not in matrix.")

init = rep(0, nrow(tmat)); names(init) = rownames(tmat)
init[current_state] = 1

pk = tmat %^% k
out = as.numeric(init %*% pk)
names(out) = colnames(tmat)
out
}

get_state_at = function(pid, seq_long_dt, yr, season_idx) {
st = seq_long_dt[parcel_id == pid & year == yr & season == season_idx, CLASS]
if (length(st) == 0) stop("No observation for this parcel at that (year, season).")
if (length(st) > 1) stop("Multiple rows found; check duplicates.")
st
}

# one prediction per year for focus season, anchored at latest observed year for that season
predict_yearly = function(tmat, pid, seq_long_dt,
end_year, season_idx,
anchor_year = NULL,
return_probs = FALSE) {

if (!is.matrix(tmat)) tmat = as.matrix(tmat)
stopifnot(all(rownames(tmat) == colnames(tmat)))

season_idx = as.integer(season_idx)
if (!(season_idx %in% 1:4)) stop("season_idx must be 1..4")

# find anchor year (latest observed year for that parcel+season),
if (is.null(anchor_year)) {
yrs_avail = seq_long_dt[parcel_id == pid & season == season_idx, unique(year)]
if (length(yrs_avail) == 0) stop("No observations for this parcel_id at that season.")
anchor_year = max(yrs_avail, na.rm = TRUE)
}

cur = get_state_at(pid, seq_long_dt, anchor_year, season_idx)

years = seq.int(anchor_year, end_year)
k_vec = (years - anchor_year) * 4 # same season each year = 4 steps per year

preds = character(length(years))
top_p = numeric(length(years))
probs_list = if (return_probs) vector("list", length(years)) else NULL

for (i in seq_along(years)) {
k = k_vec[i]

if (k == 0) {
p = rep(0, nrow(tmat)); names(p) = rownames(tmat); p[cur] = 1
} else {
p = predict_k_steps(tmat, cur, k)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd recommend that you rethink what you're doing here to make the predictions iteratively (use the last prediction as the input to the next prediction) rather than repeating the prediction from 0 to k each time k increases by 1

}

preds[i] = names(which.max(p))
top_p[i] = max(p)
if (return_probs) probs_list[[i]] = p
}

out = data.table(
parcel_id = pid,
season = season_idx,
year = years,
steps_ahead = k_vec,
anchor_year = anchor_year,
anchor_state = cur,
pred_class = preds,
pred_prob = top_p
)
if (return_probs) out[, probs := probs_list]
out
}

#anchor at latest observed year for season s, then predict to future year
test = predict_yearly(
tmat = tmat_final,
pid = "1",
seq_long_dt = seq_long,
end_year = 2030,
season_idx = 3,
return_probs = TRUE
)

print(test)
paste(test$pred_class, collapse = "-")
110 changes: 110 additions & 0 deletions modules/data.remote/inst/transition_matrix.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
ry(data.table)
library(stringr)

file_2018 = readRDS("/projectnb/dietzelab/ananyak/annual_landiq_PFT_2018.rds")
file_2019 = readRDS("/projectnb/dietzelab/ananyak/annual_landiq_PFT_2019.rds")
file_2020 = readRDS("/projectnb/dietzelab/ananyak/annual_landiq_PFT_2020.rds")
file_2021 = readRDS("/projectnb/dietzelab/ananyak/annual_landiq_PFT_2021.rds")
file_2022 = readRDS("/projectnb/dietzelab/ananyak/annual_landiq_PFT_2022.rds")
file_2023 = readRDS("/projectnb/dietzelab/ananyak/annual_landiq_PFT_2023.rds")

crops_full = rbind(file_2018, file_2019, file_2020, file_2021, file_2022, file_2023)
setDT(crops_full)

setorder(crops_full, parcel_id, year, season)

crop_sequences = crops_full[, .(
crop_sequence = paste(CLASS, collapse = "-")
), by = .(parcel_id, year)]


#merging rules
fix_seq = function(seq) {
parts = strsplit(seq, "-", fixed = TRUE)[[1]]
n = length(parts)

if (n > 1) for (i in 2:n) if (parts[i] == "X") parts[i] = parts[i - 1]
if (n > 1) for (i in 2:n) if (parts[i - 1] == "YP" && parts[i] == "**") parts[i] = "P"
if (n > 1) for (i in 1:(n - 1)) if (parts[i] %in% c("**", "X") && parts[i + 1] == "P") parts[i] = "YP"

vals = parts[parts != "**"]
u = unique(vals)
if (length(u) == 1 && all(parts %in% c(u, "**"))) parts = rep(u, n)

paste(parts, collapse = "-")
}

drop_sequences = c("**-**-**-**", "U-U-U-U", "UL-UL-UL-UL")
crop_sequences = crop_sequences[!crop_sequence %chin% drop_sequences]
crop_sequences[, crop_sequence := vapply(crop_sequence, fix_seq, character(1))]

##transition format df for matrix
#this unfortunately takes a while, this was the only way I could think of writing this
#saved seq_long (and final transition matrix) as a csv at bottom so this only has to be run once
#the prediction file (predict_and_stroe) reloads the saved files

seq_long = crop_sequences[, {
parts = strsplit(crop_sequence, "-", fixed = TRUE)[[1]]
data.table(season = seq_along(parts), CLASS = parts)
}, by = .(parcel_id, year)]

setorder(seq_long, parcel_id, year, season)

seq_long[, `:=`(
from = CLASS,
to = shift(CLASS, type = "lead"),
next_year = shift(year, type = "lead")
), by = parcel_id]

transitions_full = seq_long[
!is.na(to) & next_year == year + 1 & season == season_idx,
.(N = .N),
by = .(from, to)]

#build matrix
states_all = c("**","V","P","X","G","YP","U","D","C","I","T","F","R","UL")

tmat_counts = dcast(transitions_full, from ~ to, value.var = "N", fill = 0)

# add missing columns
missing_cols = setdiff(states_all, colnames(tmat_counts))
for (mc in missing_cols) tmat_counts[[mc]] = 0

# add missing rows
missing_rows = setdiff(states_all, tmat_counts$from)
if (length(missing_rows) > 0) {
zero_rows = data.table(from = missing_rows)
for (s in states_all) zero_rows[[s]] = 0
tmat_counts = rbind(tmat_counts, zero_rows, fill = TRUE)
}

# order matrix
tmat_counts[, ord := match(from, states_all)]
setorder(tmat_counts, ord)
tmat_counts[, ord := NULL]
tmat_counts = tmat_counts[, c("from", states_all), with = FALSE]


#normalize
rn = tmat_counts$from
prob_mat = as.matrix(tmat_counts[, ..states_all])
storage.mode(prob_mat) = "double"

#smoothing
prob_mat = prob_mat + 1e-3

#normalize again
tmat_final = prob_mat / rowSums(prob_mat)

rownames(tmat_final) = rn
colnames(tmat_final) = states_all

#checks
stopifnot(all(rownames(tmat_final) == states_all))
stopifnot(all(colnames(tmat_final) == states_all))

##save final transition matrix and huge transition sequence file (seq_long)
write.csv(seq_long, 'seq_long.csv')
write.csv(tmat_final, 'full_transition_matrix.csv')


Loading