## ----setup, include = FALSE, warning = FALSE----------------------------------
# Global chunk options for this vignette.
# - comment = "" drops the default "##" prefix in front of printed output;
#   knitr documents "" (or NA) for this, FALSE is not a documented value.
# - warning/message = FALSE hide warnings and messages in the rendered output.
knitr::opts_chunk$set(comment = "",
                      warning = FALSE,
                      message = FALSE)

# Several steps below are stochastic (k-means starts, graph clustering,
# bootstrapping). Fix the RNG seed so the results shown are reproducible.
set.seed(seed = 1234)

## -----------------------------------------------------------------------------
library(scBubbletree)
library(ggplot2)
library(ggtree)
library(patchwork)

## -----------------------------------------------------------------------------
# # This script can be used to generate data("d_ccl", package = "scBubbletree")
# 
# # create directory
# dir.create(path = "case_study/")
# 
# # download the data from:
# # https://github.com/LuyiTian/sc_mixology/raw/master/data/sincell_with_class_5cl.RData
# 
# # load the data
# load(file = "case_study/sincell_with_class_5cl.RData")
# 
# # we are only interested in the 10x data object 'sce_sc_10x_5cl_qc'
# d <- sce_sc_10x_5cl_qc
# 
# # remove the remaining objects (cleanup)
# rm(sc_Celseq2_5cl_p1, sc_Celseq2_5cl_p2, sc_Celseq2_5cl_p3, sce_sc_10x_5cl_qc)
# 
# # get the meta data (demuxlet cell line label and QC metrics) for each cell
# meta <- SummarizedExperiment::colData(d)[, c("cell_line_demuxlet",
#                                              "non_mt_percent",
#                                              "total_features")]
# 
# # create Seurat object from the raw counts and append the meta data to it
# d <- Seurat::CreateSeuratObject(counts = d@assays$data$counts,
#                                 project = '')
# 
# # check if all cells are matched between d and meta
# # table(rownames(d@meta.data) == meta@rownames)
# d@meta.data <- cbind(d@meta.data, meta@listData)
# 
# # cell line labels (from demuxlet) are provided as part of the meta data
# table(d@meta.data$cell_line_demuxlet)
# 
# # select 5,000 most variable genes
# d <- Seurat::FindVariableFeatures(object = d,
#                                   selection.method = "vst",
#                                   nfeatures = 5000)
# 
# # Preprocessing with Seurat: SCT transformation + PCA
# d <- Seurat::SCTransform(object = d,
#                          variable.features.n = 5000)
# d <- Seurat::RunPCA(object = d,
#                     npcs = 50,
#                     features = Seurat::VariableFeatures(object = d))
# 
# # perform UMAP + t-SNE
# d <- Seurat::RunUMAP(d, dims = 1:15)
# d <- Seurat::RunTSNE(d, dims = 1:15)
# 
# # save the preprocessed data
# save(d, file = "case_study/d.RData")
# 
# # save the PCA matrix 'A', meta data 'm' and
# # marker genes matrix 'e'
# d <- get(load(file ="case_study/d.RData"))
# A <- d@reductions$pca@cell.embeddings[, 1:15]
# m <- d@meta.data
# e <- t(as.matrix(d@assays$SCT@data[
#   rownames(d@assays$SCT@data) %in%
#     c("ALDH1A1",
#       "PIP4K2C",
#       "SLPI",
#       "CT45A2",
#       "CD74"), ]))
# 
# d_ccl <- list(A = A, m = m, e = e)
# save(d_ccl, file = "data/d_ccl.RData")

## -----------------------------------------------------------------------------
# Load the bundled data set and unpack its three components:
# A = PCA embedding (cells x PCs), m = per-cell meta data,
# e = expression matrix of a handful of marker genes.
data("d_ccl", package = "scBubbletree")
A <- d_ccl[["A"]]
m <- d_ccl[["m"]]
e <- d_ccl[["e"]]

## -----------------------------------------------------------------------------
# Cells are rows and features (principal components) are columns, so this
# prints the number of cells followed by the number of features.
c(nrow(A), ncol(A))

## -----------------------------------------------------------------------------
# Scan k = 1..10 k-means clusters (MacQueen). The result provides gap
# statistic and WCSS summaries per k; B_gap presumably sets the number of
# gap-statistic resamples (see ?get_k).
b_k <- get_k(x = A,
             ks = seq_len(10),
             kmeans_algorithm = "MacQueen",
             n_start = 50,
             iter_max = 200,
             B_gap = 5,
             cores = 1)

## -----------------------------------------------------------------------------
# Two side-by-side panels over k: the gap statistic (error bars from the
# L95/H95 columns) and the WCSS on a log10 y-axis. The panels are built in a
# local scope so no helper objects are left in the workspace.
g0 <- local({
  gap_panel <- ggplot(data = b_k$gap_stats_summary) +
    geom_line(aes(x = k, y = gap_mean)) +
    geom_point(aes(x = k, y = gap_mean), size = 1) +
    geom_errorbar(aes(x = k, y = gap_mean, ymin = L95, ymax = H95),
                  width = 0.1) +
    ylab(label = "Gap")

  wcss_panel <- ggplot(data = b_k$wcss_stats_summary) +
    geom_line(aes(x = k, y = wcss_mean)) +
    geom_point(aes(x = k, y = wcss_mean), size = 1) +
    ylab(label = "WCSS") +
    scale_y_log10() +
    annotation_logticks(base = 10, sides = "l")

  gap_panel | wcss_panel
})

## ----fig.width=7, fig.height=2.5, fig.align='center'--------------------------
g0

## -----------------------------------------------------------------------------
# Scan nine resolutions r = 10^-4, 10^-3.5, ..., 10^0 with graph-based
# clustering (algorithm = "original", knn_k = 50 neighbours).
b_r <- get_r(x = A,
             rs = 10^seq(from = -4, to = 0, by = 0.5),
             algorithm = "original",
             knn_k = 50,
             n_start = 10,
             iter_max = 50,
             B_gap = 5,
             cores = 1)

## -----------------------------------------------------------------------------
# 2 x 2 figure for the resolution scan: gap statistic (top row) and WCSS
# (bottom row), plotted against k' (left, the number of clusters obtained)
# and against r on a log10 axis (right). Panels are built in a local scope
# so only the combined figure is kept.
g0_r <- local({
  gap_vs_k <- ggplot(data = b_r$gap_stats_summary) +
    geom_line(aes(x = k, y = gap_mean)) +
    geom_point(aes(x = k, y = gap_mean), size = 1) +
    geom_errorbar(aes(x = k, y = gap_mean, ymin = L95, ymax = H95),
                  width = 0.1) +
    ylab(label = "Gap") +
    xlab(label = "k'")

  gap_vs_r <- ggplot(data = b_r$gap_stats_summary) +
    geom_line(aes(x = r, y = gap_mean)) +
    geom_point(aes(x = r, y = gap_mean), size = 1) +
    geom_errorbar(aes(x = r, y = gap_mean, ymin = L95, ymax = H95),
                  width = 0.1) +
    ylab(label = "Gap") +
    xlab(label = "r") +
    scale_x_log10() +
    annotation_logticks(base = 10, sides = "b")

  wcss_vs_k <- ggplot(data = b_r$wcss_stats_summary) +
    geom_line(aes(x = k, y = wcss_mean)) +
    geom_point(aes(x = k, y = wcss_mean), size = 1) +
    ylab(label = "WCSS") +
    xlab(label = "k'")

  wcss_vs_r <- ggplot(data = b_r$wcss_stats_summary) +
    geom_line(aes(x = r, y = wcss_mean)) +
    geom_point(aes(x = r, y = wcss_mean), size = 1) +
    ylab(label = "WCSS") +
    xlab(label = "r") +
    scale_x_log10() +
    annotation_logticks(base = 10, sides = "b")

  (gap_vs_k | gap_vs_r) / (wcss_vs_k | wcss_vs_r)
})

## ----fig.width=6, fig.height=4, fig.align='center'----------------------------
g0_r

## ----fig.width=4, fig.height=3, fig.align='center'----------------------------
# Number of clusters k' obtained at each resolution r (log10 x-axis).
ggplot(data = b_r$gap_stats_summary) +
  geom_point(aes(x = r, y = k), size = 1) +
  labs(x = "r", y = "k'") +
  scale_x_log10() +
  annotation_logticks(base = 10, sides = "b")

## -----------------------------------------------------------------------------
# Gap-statistic rows of the resolution scan for which k' = 5.
local({
  gap_r <- b_r$gap_stats_summary
  knitr::kable(x = gap_r[gap_r$k == 5, ],
               digits = 4,
               row.names = FALSE)
})

## -----------------------------------------------------------------------------
# Bubble tree of k = 5 k-means (MacQueen) clusters. B and N_eff configure
# the bootstrap (presumably number of replicates and cells per cluster --
# see ?get_bubbletree_kmeans).
k5_kmeans <- get_bubbletree_kmeans(
  x = A,
  k = 5,
  kmeans_algorithm = "MacQueen",
  B = 200,
  N_eff = 200,
  round_digits = 1,
  show_simple_count = FALSE,
  cores = 1)

## ----fig.width=4, fig.height=3, fig.align='center'----------------------------
k5_kmeans$tree

## -----------------------------------------------------------------------------
# Tabular summary of the bubbles shown in the tree.
knitr::kable(x = k5_kmeans$tree_meta,
             row.names = FALSE,
             digits = 2)

## -----------------------------------------------------------------------------
# Graph-based counterpart of the k-means tree at resolution r = 0.1
# (presumably chosen because it yields k' = 5, cf. the table above).
k5_louvain <- get_bubbletree_graph(x = A,
                                   r = 0.1,
                                   algorithm = "original",
                                   knn_k = 50,
                                   n_start = 20,
                                   iter_max = 100,
                                   B = 200,
                                   N_eff = 200,
                                   round_digits = 1,
                                   show_simple_count = FALSE,
                                   cores = 1)

## ----fig.width=4, fig.height=3, fig.align='center'----------------------------
k5_louvain$tree

## ----fig.width=7, fig.height=3, fig.align='center'----------------------------
# k-means tree (left) next to the graph-based tree (right).
k5_kmeans$tree | k5_louvain$tree

## ----fig.width=7, fig.height=4, fig.align='center'----------------------------
# Cell-line composition of each bubble as annotated tiles.
# integrate_vertical selects the direction in which the tile values are
# computed (see ?get_cat_tiles); both variants are drawn below.
w1 <- get_cat_tiles(btd = k5_kmeans,
                    f = m$cell_line_demuxlet,
                    integrate_vertical = TRUE,
                    x_axis_name = "Cell line",
                    rotate_x_axis_labels = TRUE,
                    round_digits = 1,
                    tile_text_size = 2.75)

(k5_kmeans$tree | w1$plot) +
  patchwork::plot_layout(widths = c(1, 1))

## ----fig.width=7, fig.height=4, fig.align='center'----------------------------
w2 <- get_cat_tiles(btd = k5_kmeans,
                    f = m$cell_line_demuxlet,
                    integrate_vertical = FALSE,
                    x_axis_name = "Cell line",
                    rotate_x_axis_labels = TRUE,
                    round_digits = 1,
                    tile_text_size = 2.75)

(k5_kmeans$tree | w2$plot) +
  patchwork::plot_layout(widths = c(1, 1))

## ----fig.width=9, fig.height=4, fig.align='center'----------------------------
# Tree plus both tile variants in one tagged figure (A, B, C).
(k5_kmeans$tree | w1$plot | w2$plot) +
  patchwork::plot_layout(widths = c(1, 2, 2)) +
  patchwork::plot_annotation(tag_levels = "A")

## -----------------------------------------------------------------------------
# Gini index (the `gi` element) of the k = 5 k-means partition with respect
# to the demuxlet cell-line labels.
get_gini(clusters = k5_kmeans$cluster,
         labels = m$cell_line_demuxlet)$gi

## -----------------------------------------------------------------------------
# Weighted Gini index for every k scanned by get_k() (object b_k).
gini_boot <- get_gini_k(obj = b_k,
                        labels = m$cell_line_demuxlet)

## ----fig.width=4, fig.height=3, fig.align='center'----------------------------
g1 <- ggplot(data = gini_boot$wgi_summary) +
  geom_point(aes(x = k, y = wgi), size = 0.75) +
  labs(y = "WGI") +
  ylim(0, 1)

g1

## ----fig.width=8, fig.height=4, fig.align='center'----------------------------
# Per-bubble mean of each marker gene in e, drawn as annotated tiles.
w3 <- get_num_tiles(btd = k5_kmeans,
                    fs = e,
                    x_axis_name = "Gene expression",
                    summary_function = "mean",
                    round_digits = 1,
                    tile_text_size = 2.75,
                    rotate_x_axis_labels = TRUE)

(k5_kmeans$tree | w3$plot) +
  patchwork::plot_layout(widths = c(1, 1))

## ----fig.width=10, fig.height=4, fig.align='center'---------------------------
# The same genes as per-bubble distributions (violins), next to the tiles.
w4 <- get_num_violins(btd = k5_kmeans,
                      fs = e,
                      x_axis_name = "Gene expression",
                      rotate_x_axis_labels = TRUE)

(k5_kmeans$tree | w3$plot | w4$plot) +
  patchwork::plot_layout(widths = c(1.5, 2, 2.5)) +
  patchwork::plot_annotation(tag_levels = "A")

## ----fig.width=9, fig.height=4, fig.align='center'----------------------------
# Per-bubble distributions of three QC metrics for the k-means tree.
# NOTE(review): the MT value is computed as 1 - non_mt_percent; if
# non_mt_percent is a fraction in [0, 1], this axis shows a fraction rather
# than the percentage the 'MT [%]' label suggests -- confirm against the data.
w_mt_dist <- get_num_violins(btd = k5_kmeans,
                             x_axis_name = "MT [%]",
                             fs = 1 - m$non_mt_percent,
                             rotate_x_axis_labels = TRUE)

w_umi_dist <- get_num_violins(btd = k5_kmeans,
                              x_axis_name = "RNA count (in thousands)",
                              fs = m$nCount_RNA / 1000,
                              rotate_x_axis_labels = TRUE)

w_gene_dist <- get_num_violins(btd = k5_kmeans,
                               x_axis_name = "Gene count",
                               fs = m$nFeature_RNA,
                               rotate_x_axis_labels = TRUE)


(k5_kmeans$tree | w_mt_dist$plot | w_umi_dist$plot | w_gene_dist$plot) +
  patchwork::plot_layout(widths = c(0.7, 1, 1, 1)) +
  patchwork::plot_annotation(tag_levels = "A")

## ----fig.width=7, fig.height=4, fig.align='center'----------------------------
# Bubble trees can be built from any clustering via get_bubbletree_dummy().
# Here the clusters come from k = 5 PAM (k-medoids, Euclidean distance).
pam_k5 <- cluster::pam(x = A,
                       k = 5,
                       metric = "euclidean")

# cores = 1, matching every other scBubbletree call in this vignette.
dummy_k5_pam <- get_bubbletree_dummy(x = A,
                                     cs = pam_k5$clustering,
                                     B = 200,
                                     N_eff = 200,
                                     cores = 1,
                                     round_digits = 1)

# PAM tree next to its cell-line composition tiles.
dummy_k5_pam$tree |
  get_cat_tiles(btd = dummy_k5_pam,
                f = m$cell_line_demuxlet,
                integrate_vertical = TRUE,
                round_digits = 1,
                tile_text_size = 2.75,
                x_axis_name = 'Cell line',
                rotate_x_axis_labels = TRUE)$plot

## -----------------------------------------------------------------------------
# Record R version, platform and attached package versions.
utils::sessionInfo()