function [shift_filt] = mu_filt_rtms(data,pulsetime,diag)
%MU_FILT_RTMS  8-12 Hz band-pass filter pre-stimulus rTMS data with clamping and delay correction.
%
% Written by Sara Hussain as a modification of lcl_filt for use with prestim rTMS data
% for SS_063; see Hussain et al. 2019, Cerebral Cortex.
%
% The function first removes the mean and linear trend from each trial (column).
% Every sample from 10 samples before the TMS pulse onward is then clamped to the value
% at that sample, so post-pulse data cannot enter the filter.
%
% Each trial is filtered once, forward and causally, with a fir1 band-pass (8-12 Hz).
% The filter uses a Blackman-Harris window of order N = 500 ms * 5 samples/ms = 2500,
% i.e. N+1 taps. The output is then circularly shifted back by N/2 samples to correct
% the filter delay.
%
% Optionally, the function plots raw (clamped, detrended) vs. filtered signal for up to
% 5 randomly selected trials, so phase agreement can be checked by eye.
%
% Input arguments:
%  - data:      time x trial matrix of data to be filtered. The sampling rate is
%               hard-coded to fs = 5000 Hz. Data recorded at any other rate will be
%               filtered with the wrong pass band.
%  - pulsetime: sample index (row of data) at which the TMS pulse occurred. Must be an
%               integer with 1 <= pulsetime - 10 <= size(data,1).
%  - diag:      optional. Use 'yes' (case-insensitive) to show the diagnostic plots; the
%               function then blocks until that figure is closed. Any other value
%               (default 'no') skips the plots.
%
% Output arguments:
%  - shift_filt: time x trial matrix (same size as data) of filtered, delay-corrected data.
%                The delay correction is a circular shift, so the last N/2 rows hold the
%                wrapped-around start-up transient of the filter and should not be used.

if nargin < 3 || isempty(diag)
    diag = 'no';
end

%set parameters
fs=5000;                 %sampling rate (Hz), hard-coded
datalength=size(data,1);
trialnum=size(data,2);
samples_per_ms=fs/1000;
ms_window=500;           %filter length in ms
window=10;               %clamping starts this many samples before the pulse

clamp_start=pulsetime-window;
if ~isscalar(pulsetime) || clamp_start~=fix(clamp_start) || clamp_start<1 || clamp_start>datalength
    error('mu_filt_rtms:badPulsetime', ...
        'pulsetime must be an integer with 1 <= pulsetime-%d <= size(data,1) (= %d).', ...
        window, datalength);
end

%set filter order (N) and create Blackman-Harris window
N=samples_per_ms*ms_window;          %filter order
blackman_window=blackmanharris(N+1); %window length must be N+1 (number of taps)
delay=N/2;                           %samples to shift back after causal filtering

%design the band-pass filter with band edges normalized to the Nyquist frequency
nyq=fs/2;
flt=fir1(N,[8 12]/nyq,blackman_window);

%demean and linearly detrend each trial (column) before filtering
data=detrend(data,'constant');
data=detrend(data,'linear');

%clamp: from clamp_start to the end, hold each trial at its value at clamp_start
clamp_dat=data;
clamp_dat(clamp_start:end,:)=repmat(data(clamp_start,:),datalength-clamp_start+1,1);

%filter down each column (dim 1, zero initial conditions), then shift back by the delay.
%The explicit dim keeps this column-wise even if there is only one time sample.
filtered_data=filter(flt,1,clamp_dat,[],1);
shift_filt=circshift(filtered_data,[-delay 0]);

%check phase agreement using diagnostic plots
if strcmpi(diag,'yes')
    nplot=min(5,trialnum);
    random_trials=randperm(trialnum,nplot);  %distinct trial indices
    half_win=fs*0.25;                         %+/-250 ms around the pulse
    %clip the plotting window to the available samples
    idx=max(1,pulsetime-half_win):min(datalength,pulsetime+half_win);

    plotraw=clamp_dat(idx,random_trials);
    plotfilt=shift_filt(idx,random_trials);

    fig=figure;
    for k=1:nplot
        raw_k=detrend(plotraw(:,k),'constant');
        raw_k=detrend(raw_k,'linear');
        subplot(nplot,1,k)
        plot(raw_k); hold on; plot(plotfilt(:,k));
        title(sprintf('trial %d',random_trials(k)));
    end

    %block until the diagnostic figure created above (not "figure 1") is closed
    uiwait(fig);
end

end
