Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 57 additions & 108 deletions Rfunction/ModelAnalysisPlot.R
Original file line number Diff line number Diff line change
@@ -1,69 +1,53 @@
if(!require(ggplot2)) install.packages("ggplot2")
if(!require(dplyr)) install.packages("dplyr")
if(!require(patchwork)) install.packages("patchwork")

library(dplyr)
library(ggplot2)
library(patchwork)

ModelAnalysisPlot=function(trace_path,Stoch = F,print=T){
ModelAnalysisPlot=function(solverName_path,Stoch = F,print=T,ncol=2){

trace <-read.csv(trace_path,sep = "")
trace <-read.csv(solverName_path,sep = "")
n_sim_tot<-table(trace$Time)
n_sim <- n_sim_tot[1]
time_delete<-as.numeric(names(n_sim_tot[n_sim_tot!=n_sim_tot[1]]))

if(length(time_delete)!=0) trace = trace[which(trace$Time!=time_delete),]

trace$ID <- rep(1:n_sim[1],each = length(unique(trace$Time)) )

trace.final <- lapply(colnames(trace)[-which( colnames(trace)%in% c("ID","Time"))],function(c){
NoindexPlaces=which( colnames(trace)%in% c("ID","Time"))
trace.final <- lapply(colnames(trace)[-NoindexPlaces],function(c){
return(data.frame(V=trace[,c], ID = trace$ID,Time=trace$Time,Compartment=c ) )
})
trace.final <- do.call("rbind",trace.final)
#Create line plots
l_plot_line=lapply(colnames(trace)[-NoindexPlaces],function(namecl,trace){
return (ggplot( )+
geom_line(data=trace,
aes(x=Time,y=get(namecl),group=ID))+
theme(axis.text=element_text(size=18),
axis.title=element_text(size=20,face="bold"),
legend.text=element_text(size=18),
legend.title=element_text(size=20,face="bold"),
legend.position="right",
legend.key.size = unit(1.3, "cm"),
legend.key.width = unit(1.3,"cm") )+
labs(x="Days", y=namecl))

},trace)
names(l_plot_line)=colnames(trace)[-NoindexPlaces]


plI<-ggplot( )+
geom_line(data=trace,
aes(x=Time,y=I,group=ID))+
theme(axis.text=element_text(size=18),
axis.title=element_text(size=20,face="bold"),
legend.text=element_text(size=18),
legend.title=element_text(size=20,face="bold"),
legend.position="right",
legend.key.size = unit(1.3, "cm"),
legend.key.width = unit(1.3,"cm") )+
labs(x="Days", y="I")

plS<-ggplot( )+
geom_line(data=trace,
aes(x=Time,y=S,group=ID))+
theme(axis.text=element_text(size=18),
axis.title=element_text(size=20,face="bold"),
legend.text=element_text(size=18),
legend.title=element_text(size=20,face="bold"),
legend.position="right",
legend.key.size = unit(1.3, "cm"),
legend.key.width = unit(1.3,"cm") )+
labs(x="Days", y="S")

plR<-ggplot( )+
geom_line(data=trace,
aes(x=Time,y=R,group=ID))+
theme(axis.text=element_text(size=18),
axis.title=element_text(size=20,face="bold"),
legend.text=element_text(size=18),
legend.title=element_text(size=20,face="bold"),
legend.position="right",
legend.key.size = unit(1.3, "cm"),
legend.key.width = unit(1.3,"cm") )+
labs(x="Days", y="R")

if(Stoch){
meanTrace <- trace %>% group_by(Time) %>%
summarise(S=mean(S),I=mean(I),R=mean(R))

summarise_at(colnames(trace)[-NoindexPlaces], mean)
meanTrace.final <- lapply(colnames(meanTrace)[-which( colnames(meanTrace)=="Time")],function(c){
return(data.frame(V=unlist(meanTrace[,c]), Time=meanTrace$Time,Compartment=c ) )
})
meanTrace.final <- do.call("rbind",meanTrace.final)


meanTrace.final <- do.call("rbind",meanTrace.final)
plAll <-ggplot( )+
geom_line(data=trace.final,
aes(x=Time,y=V,group=ID))+
Expand All @@ -78,7 +62,7 @@ ModelAnalysisPlot=function(trace_path,Stoch = F,print=T){
legend.position="bottom",
legend.key.size = unit(1, "cm"),
legend.key.width = unit(1,"cm") )+
labs(x="Days", y="Population")
labs(x="Days", y="count")


plAllMean <-ggplot( )+
Expand All @@ -92,79 +76,44 @@ ModelAnalysisPlot=function(trace_path,Stoch = F,print=T){
legend.position="right",
legend.key.size = unit(1.3, "cm"),
legend.key.width = unit(1.3,"cm") )+
labs(x="Days", y="Mean Population")

plIdens<-ggplot(trace[trace$Time==max(trace$Time),])+
geom_histogram(aes(I))+
theme(axis.text=element_text(size=18),
axis.title=element_text(size=20,face="bold"),
legend.text=element_text(size=18),
legend.title=element_text(size=20,face="bold"),
legend.position="right",
legend.key.size = unit(1.3, "cm"),
legend.key.width = unit(1.3,"cm") )

plSdens<-ggplot(trace[trace$Time==max(trace$Time),])+
geom_histogram(aes(S))+
theme(axis.text=element_text(size=18),
axis.title=element_text(size=20,face="bold"),
legend.text=element_text(size=18),
legend.title=element_text(size=20,face="bold"),
legend.position="right",
legend.key.size = unit(1.3, "cm"),
legend.key.width = unit(1.3,"cm") )
labs(x="Days", y="Mean")

plRdens<-ggplot(trace[trace$Time==max(trace$Time),])+
geom_histogram(aes(R))+
#Create Histograms
l_plot_hist=lapply(colnames(trace)[-NoindexPlaces],function(namecl,trace){
return (ggplot(trace[trace$Time==max(trace$Time),])+
geom_histogram(aes(get(namecl)))+
theme(axis.text=element_text(size=18),
axis.title=element_text(size=20,face="bold"),
legend.text=element_text(size=18),
legend.title=element_text(size=20,face="bold"),
legend.position="right",
legend.key.size = unit(1.3, "cm"),
legend.key.width = unit(1.3,"cm") )

plI<-plI+
geom_line(data=meanTrace,
aes(x=Time,y=I,col="Mean"),
linetype="dashed")+
labs(x="Days", y="I",col="")

plS<-plS+
geom_line(data=meanTrace,
aes(x=Time,y=S,col="Mean"),
linetype="dashed")+
labs(x="Days", y="S",col="")
legend.key.width = unit(1.3,"cm") )+labs(x=namecl))
},trace)
names(l_plot_hist)=colnames(trace)[-NoindexPlaces]

plR<-plR+
geom_line(data=meanTrace,
aes(x=Time,y=R,col="Mean"),
linetype="dashed")+
labs(x="Days", y="R",col="")
#Adding mean in the line plot
l_plot_line_mean=lapply(colnames(trace)[-NoindexPlaces],function(namecl,meanTrace){
return (l_plot_line[[namecl]]+ geom_line(data=meanTrace,
aes(x=Time,y=get(namecl),col="Mean"),
linetype="dashed")+
labs(x="Days", y=namecl,col=""))
},meanTrace)

names(l_plot_line_mean)=colnames(trace)[-NoindexPlaces]

ListReturn<-list(plS = plS,plI = plI,plR = plR,
HistS = plSdens,HistI = plIdens,HistR = plRdens,
plAll=plAll,plAllMean=plAllMean)
(all_plot_line_mean=wrap_plots(l_plot_line_mean,ncol = ncol))
(all_plot_hist=wrap_plots(l_plot_hist,ncol = ncol))
if (print){
print(all_plot_line_mean)
print(all_plot_hist)
}
return(c(all_plot_line_mean,all_plot_hist))
}else{
plAll <-ggplot( )+
geom_line(data=trace.final,
aes(x=Time,y=V,col=Compartment))+
theme(axis.text=element_text(size=18),
axis.title=element_text(size=20,face="bold"),
legend.text=element_text(size=18),
legend.title=element_text(size=20,face="bold"),
legend.position="right",
legend.key.size = unit(1.3, "cm"),
legend.key.width = unit(1.3,"cm") )+
labs(x="Days", y="Population")
ListReturn<-list(plS = plS,plI = plI,plR = plR,plAll=plAll)
}

if(print){
for(j in 1:length(ListReturn))
print(ListReturn[j])
all_plot_line=wrap_plots(l_plot_line,ncol = ncol)
if (print){
print( all_plot_line)
}
return(all_plot_line)
}

return(ListReturn)
}