%% Read in spike data
% Load the sorted spike file, then keep only the units of interest.
filename = 'p029_sort_final_01.nex';
spike = ft_read_spike(filename);

% Unit selection. Use 'all' to keep every unit, or 'gui' to pick units
% interactively when their names are not known in advance.
% (Double braces make struct() store the cell array as a single field value.)
cfg = struct('spikechannel', {{'sig002a_wf', 'sig003a_wf'}});
spike = ft_spike_select(cfg, spike);
%% Read out LFP data using trial function
% Build the trial definition (cfg.trl) with the custom trial function.
cfg = [];
cfg.dataset = filename;
cfg.trialfun = 'trialfun_stimon_samples';
cfg = ft_definetrial(cfg);

% Read in the data in trials, removing 60 Hz line noise.
% These channels contain the LFP.
cfg.channel = {'AD01', 'AD02', 'AD03', 'AD04'};
% Length (in seconds) to which each trial is padded before filtering.
% A padded trial of T seconds has a frequency resolution of 1/T Hz.
cfg.padding = 10;
% Remove the line-noise frequency plus its two neighbouring frequency bins
% (59.9, 60 and 60.1 Hz for 10 s padding). Deriving the bin spacing from
% cfg.padding keeps the bins on the frequency grid if the padding changes.
lineFreq = 60;
cfg.dftfreq = lineFreq + (-1:1) / cfg.padding;
cfg.dftfilter = 'yes';
data_lfp = ft_preprocessing(cfg); % Read in the LFP
%% Create trials for spike structure
% Re-run the same trial function used for the LFP so the spike data is cut
% into the same trials.
cfg = struct('dataset', filename, 'trialfun', 'trialfun_stimon_samples');
cfg = ft_definetrial(cfg);
trl = cfg.trl;

% The trl is expressed in samples. The LFP header carries the information
% needed to convert samples to spike timestamps.
cfg = struct('hdr', data_lfp.hdr, 'trlunit', 'samples', 'trl', trl);
spikeTrials = ft_spike_maketrials(cfg, spike);
%% Link spike data to LFP data
% NOTE: the continuous spike structure (spike), not spikeTrials, is
% appended to the trial-based LFP here.
data_all = ft_appendspike([], data_lfp, spike);
%% Correct for strong locking between spikes and high frequency (same-electrode) LFP components
% Blank a 4 ms window (-2 ms to +2 ms) around every spike of the first unit
% in the second LFP channel. This presumably is the electrode on which the
% unit was recorded; verify against the channel/unit naming.
cfg = [];
cfg.timwin = [-0.002 0.002];
cfg.spikechannel = spike.label{1};
cfg.channel = data_lfp.label(2);

% Variant 1: replace the removed segment with NaNs.
cfg.method = 'nan';
data_nan = ft_spiketriggeredinterpolation(cfg, data_all);

% Variant 2: replace the removed segment by linear interpolation.
cfg.method = 'linear';
data_i = ft_spiketriggeredinterpolation(cfg, data_all);
%% Compute spike triggered average (STA) of LFP (pre-stimulus)
% Average the LFP of the first four channels around each spike of the first
% unit. cfg.latency restricts the analysis to [-2.75 0] s (pre-stimulus).
cfg = [];
cfg.timwin = [-0.25 0.25]; % take 500 ms (250 ms either side of each spike)
cfg.spikechannel = spike.label{1}; % first unit
cfg.channel = data_lfp.label(1:4); % first four chans
cfg.latency = [-2.75 0];
staPre = ft_spiketriggeredaverage(cfg, data_all);

% Plot the STA, one line per channel. avg is transposed so that each
% channel becomes one plotted column.
figure
plot(staPre.time, staPre.avg(:,:)')
% Take labels from the STA output so the legend matches the averaged channels.
legend(staPre.label)
xlabel('Time (s)')
xlim(cfg.timwin)
title('STA of LFP (Pre-stimulus)')
%% Compute spike triggered average (STA) of LFP (post-stimulus)
% Average the LFP of the first four channels around each spike of the first
% unit. cfg.latency restricts the analysis to [0.3 10] s (post-stimulus).
cfg = [];
cfg.timwin = [-0.25 0.25]; % take 500 ms (250 ms either side of each spike)
cfg.spikechannel = spike.label{1}; % first unit
cfg.channel = data_lfp.label(1:4); % first four chans
cfg.latency = [0.3 10];
staPost = ft_spiketriggeredaverage(cfg, data_all);

% Plot the STA, one line per channel. avg is transposed so that each
% channel becomes one plotted column.
figure
plot(staPost.time, staPost.avg(:,:)')
% Take labels from the STA output so the legend matches the averaged channels.
legend(staPost.label)
xlabel('Time (s)')
xlim(cfg.timwin)
title('STA of LFP (Post-Stimulus)')