%% ECoG Pruner:
% Author: Alex Lepauvre
% Date: 13/11/2020
% This function plots the trigger signal of an EDF file and automatically
% detects the onsets of the experiment levels, which are used to chop up
% the signal. The user can manually correct or select the level onsets if
% the automatic detection is off. As a result, the original EDF file gets
% cut into separate files.
% Inputs:
% - EDF_in: full path + file name of the EDF to chop
% - PD_channel: name of the photodiode channel (or of any other kind of
% trigger channel)
% - ParticipantID: ID of the participant, used to name the saved files
% according to the naming conventions
% - EDF_out_path: path where to save the data
% Outputs:
% - Saved EDF files

%% Main functions
function ECoGPruner(EDF_in, PD_channel, ParticipantID, EDF_out_path)
% ECOGPRUNER  Interactively cut an EDF recording into one file per block.
% Workflow:
%   1. anonymize the input EDF (anonEDF, expected in helperFunctions),
%   2. plot the trigger channel and let the user confirm/choose a threshold,
%   3. detect the level onsets automatically and let the user correct them,
%   4. ask whether finger-localizer data precede the first level onset,
%   5. write one EDF per block to EDF_out_path and check each written file.
% Inputs:
% - EDF_in: full path + file name of the EDF to chop
% - PD_channel: label of the photodiode/trigger channel
% - ParticipantID: participant ID used to build the default file names
% - EDF_out_path: folder where the chopped EDF files are written

% Housekeeping:
close all

% Adding path to the helper functions (fullfile keeps it portable):
addpath(fullfile('.', 'helperFunctions'))

%% Anonymization of the EDFs:
% All further processing reads the file returned by anonEDF:
newFileName = anonEDF(EDF_in);

%% Set up and plotting:
% Loading the ECoG data:
[signal_pd, signal, SR] = loadData(newFileName, PD_channel);
% Number of samples of the recording (columns of the data matrix):
nSamples = size(signal.trial{1,1}, 2);

% Default file names:
fileNames = {[ParticipantID, '_ECoG_V1_Loc.EDF'], [ParticipantID, '_ECoG_V1_DurR1.EDF'], ...
    [ParticipantID, '_ECoG_V1_DurR2.EDF'], [ParticipantID, '_ECoG_V1_DurR3.EDF'], ...
    [ParticipantID, '_ECoG_V1_DurR4.EDF'], [ParticipantID, '_ECoG_V1_DurR5.EDF']};

% Plotting the figure:
figure(); hold on;
plot(signal_pd);
% Axes kept for zooming in and out later:
ax = gca;
% Original x range, restored after every correction round:
xlimOrig = xlim(ax);
% Initial threshold guess: halfway between the signal min and max:
PD_Threshold = (max(signal_pd) - min(signal_pd))*0.5 + min(signal_pd);
thresh = yline(PD_Threshold,'r');

% Looping until the user is happy about the threshold:
pursue = true;
while pursue
    resp = questdlg('Are you happy about this threshold?', ...
        'Threshold', ...
        'Yes','No','No');
    % The line is removed in both cases (redrawn if a new one is picked):
    delete(thresh)
    if strcmp(resp, 'Yes')
        pursue = false;
    else
        Text1 = text(0.5,0.5,'Select a new threshold', 'Units', 'normalized',...
            'FontSize',14,'HorizontalAlignment','center');
        [~, y] = ginput(1);
        delete(Text1)
        % ginput returns [] if the user presses Enter without clicking:
        if ~isempty(y)
            PD_Threshold = y;
        end
        thresh = yline(PD_Threshold,'r');
    end
end

% Getting the photodiode onsets:
photodiodeOnsets_idx = getPDOnsets(signal_pd,PD_Threshold);

% Now, finding the level onsets:
levelOnsets = sanitizeOnsets(findLevelOnsetsphotodiodeOnsets_idx(photodiodeOnsets_idx, SR), nSamples);

%% Requesting user input to fine tune detection:
happy = false;
while ~happy
    if isempty(levelOnsets)
        % Nothing detected (or everything removed): full manual input.
        levelOnsets = fullManualInputs(0);
    else
        % Plot the current level onsets:
        vertLine = plotOnsets(levelOnsets);

        options = {'Take a Closer look to adjust', 'Remove All', 'Add new onsets', 'Continue to saving'};
        idx = listdlg('PromptString',{'What would you like to do?'},...
            'SelectionMode','single','ListString',options);
        if isempty(idx)
            % Dialog closed/cancelled: simply show the same onsets again.
            choice = '';
        else
            choice = options{idx};
        end

        % The lines are redrawn at every pass, so each branch removes them:
        switch choice
            case 'Take a Closer look to adjust'
                delete(vertLine)
                levelOnsets = closerInspection(levelOnsets, SR, ax);
            case 'Add new onsets'
                % Lines stay visible while the user reads off new x values:
                levelOnsets = fullManualInputs(1, levelOnsets);
                delete(vertLine)
            case 'Remove All'
                delete(vertLine)
                levelOnsets = fullManualInputs(0);
            case 'Continue to saving'
                delete(vertLine)
                happy = true;
            otherwise
                delete(vertLine)
        end
    end
    % Clicked/typed onsets can be fractional, unordered or out of range:
    % snap them to sorted, unique, valid sample indices before chopping.
    levelOnsets = sanitizeOnsets(levelOnsets, nSamples);
    % Zooming out for the next review:
    set(ax,'XLim',xlimOrig)
end

% Plotting the final level onsets:
plotOnsets(levelOnsets);

% Question: are the finger localizers in there?
fingerLocalizers = questdlg('Are the finger localizer data found before the first level onset?', ...
    'Finger localizer', ...
    'Yes', 'No', 'No');

%% Chopping the EDFs accordingly:
% A closed dialog ('') is treated like the default answer ('No').
if strcmp(fingerLocalizers, 'Yes')
    % The segment before the first level onset is saved as its own file:
    blockOnsets = [1, levelOnsets];
    definput = [fileNames, {'', '', '', ''}];
else
    % The segment before the first level onset is not saved:
    blockOnsets = levelOnsets;
    definput = [fileNames(2:end), {'', '', '', '', ''}];
end
% Each block ends one sample before the next starts; the last one runs to
% the end of the recording:
blockOffsets = [blockOnsets(2:end) - 1, nSamples];

% Asking the user to confirm file names:
outNames = askFileNames(definput, numel(blockOnsets));
if isempty(outNames)
    warning('ECoGPruner:savingCancelled', 'File naming was cancelled: no file was saved.');
    return
end

for k = 1:numel(blockOnsets)
    sampleRange = blockOnsets(k):blockOffsets(k);
    dataBlock = signal.trial{1,1}(:, sampleRange);
    % The header must match the dimensions of the chopped data:
    headerBlock = signal.hdr;
    headerBlock.nSamples = numel(sampleRange);

    % Save the file:
    outFile = fullfile(EDF_out_path, outNames{k});
    ft_write_data(outFile, dataBlock, 'header', headerBlock);

    % Check that the saving procedure did not alter the data:
    checkDataIntegrity(outFile, signal.trial{1,1}, blockOnsets(k), blockOffsets(k))
end


%% plotOnsets
function vertLine = plotOnsets(levelOnsets)
% Draw one dashed red vertical line per level onset on the current axes.
% Returns the line handles so the caller can delete them.
vertLine = gobjects(1, numel(levelOnsets));
for k = 1:numel(levelOnsets)
    vertLine(k) = xline(levelOnsets(k), '--r', 'LineWidth', 1);
end


%% sanitizeOnsets
function levelOnsets = sanitizeOnsets(levelOnsets, nSamples)
% Round onsets to whole samples and keep only those in [1, nSamples].
% Return them as a sorted row vector without duplicates.
levelOnsets = round(levelOnsets(:)');
levelOnsets = levelOnsets(levelOnsets >= 1 & levelOnsets <= nSamples);
levelOnsets = unique(levelOnsets);


%% askFileNames
function names = askFileNames(definput, nBlocks)
% Ask the user for one output file name per block.
% Re-asks until the first nBlocks fields are non-empty; returns {} if the
% dialog is cancelled.
prompt = {'File1:','File2:', 'File3(Leave empty if none):','File4(Leave empty if none):',...
    'File5(Leave empty if none):','File6(Leave empty if none):', 'File7(Leave empty if none):',...
    'File8(Leave empty if none):', 'File9(Leave empty if none):','File10(Leave empty if none):'};
if nBlocks > numel(prompt)
    error('ECoGPruner:tooManyBlocks', ...
        '%d blocks were found but at most %d files can be named.', nBlocks, numel(prompt));
end
names = {};
done = false;
while ~done
    answer = inputdlg(prompt, 'Files naming', [1 100], definput);
    if isempty(answer)
        % Cancelled by the user:
        return
    end
    answer = strtrim(answer(:)');
    if all(~cellfun(@isempty, answer(1:nBlocks)))
        names = answer(1:nBlocks);
        done = true;
    else
        uiwait(errordlg(sprintf('Please provide a file name for each of the %d blocks.', nBlocks), ...
            'Files naming'));
        % Keep what the user typed for the next attempt:
        definput = answer;
    end
end





%% Helper functions

%% loadData
function [signal_pd, signal, SR] = loadData(data, PD_channel)
% LOADDATA  Load a dataset with FieldTrip and optionally extract one channel.
% Inputs:
% - data: path to the file to load (used as cfg.dataset for ft_preprocessing)
% - PD_channel: (optional) label of the trigger/photodiode channel
% Outputs:
% - signal_pd: samples of PD_channel taken from the first trial;
%   empty if PD_channel is not given
% - signal: structure returned by ft_preprocessing
% - SR: sampling rate of the signal (signal.fsample)
% Errors if PD_channel is given but no channel has that label.

% Setting the data path and loading the data:
cfg = [];
cfg.dataset = data;
signal = ft_preprocessing(cfg);

% Declaring the PD signal:
signal_pd = [];

% If the PD channel was specified, extract its signal:
if nargin > 1
    PD_index = find(strcmp(signal.label, PD_channel));
    if isempty(PD_index)
        % Otherwise signal_pd would silently be empty and fail much later:
        error('ECoGPruner:channelNotFound', ...
            'Channel "%s" was not found in %s.', PD_channel, data);
    elseif numel(PD_index) > 1
        warning('ECoGPruner:duplicateChannel', ...
            'Channel "%s" appears %d times in %s; using the first one.', ...
            PD_channel, numel(PD_index), data);
        PD_index = PD_index(1);
    end
    signal_pd = signal.trial{1,1}(PD_index,:);
end

% Getting the sampling rate:
SR = signal.fsample;



%% getPDOnsets
function [photodiodeOnsets_idx]= getPDOnsets(Signal,Threshold)
% GETPDONSETS  Find where a trigger signal rises above a threshold.
% Inputs:
% - Signal: trigger/photodiode samples (vector)
% - Threshold: level a sample must strictly exceed to count as "on"
% Outputs:
% - photodiodeOnsets_idx: for every rising crossing, the index of the last
%   sample that is NOT above Threshold (the first "on" sample is the next
%   index)

isOn = Signal > Threshold;
% Consecutive differences of the binary signal:
% +1 = rising edge, -1 = falling edge, 0 = no change.
transitions = diff(isOn);
photodiodeOnsets_idx = find(transitions > 0);


%% findLevelOnsetsphotodiodeOnsets_idx
function levelOnsets = findLevelOnsetsphotodiodeOnsets_idx(photodiodeOnsets_idx, SR, refRateSec)
% FINDLEVELONSETSPHOTODIODEONSETS_IDX  Detect level onsets in trigger onsets.
% A level onset is the first trigger of 4 successive triggers. The 3
% intervals between them must each be between 3 and 8 screen-refresh
% periods long.
% Inputs:
% - photodiodeOnsets_idx: sample indices of the trigger onsets
% - SR: sampling rate of the signal (Hz)
% - refRateSec: (optional) screen refresh period in seconds, default 0.016
% Outputs:
% - levelOnsets: sample indices of the detected level onsets. A run of
%   more than 4 closely spaced triggers yields one onset per 4-trigger
%   window.

if nargin < 3 || isempty(refRateSec)
    refRateSec = 0.016;
end

% Refresh period converted to samples:
refRateSamples = refRateSec * SR;

% Interval between successive triggers, and whether it is in range:
diffPD_onsets = diff(photodiodeOnsets_idx);
inRange = diffPD_onsets >= refRateSamples * 3 & diffPD_onsets <= refRateSamples * 8;

% Window k covers intervals k, k+1, k+2, i.e. triggers k..k+3. k starts at
% 1, so a burst starting at the very first trigger is found too.
isBurst = inRange(1:end-2) & inRange(2:end-1) & inRange(3:end);

levelOnsets = photodiodeOnsets_idx(find(isBurst));


%% closerInspection
function levelOnsets = closerInspection(levelOnsets, SR, ax)
% CLOSERINSPECTION  Review each level onset zoomed in and let the user
% move it with a mouse click.
% Inputs:
% - levelOnsets: current level onsets (sample indices)
% - SR: sampling rate of the signal (Hz), used to size the zoom window
% - ax: axes of the trigger plot, used to zoom in
% Outputs:
% - levelOnsets: level onsets after user correction. Clicked positions are
%   rounded to whole samples.
% Every line drawn here is deleted before moving on. The caller can
% therefore redraw all onsets itself without duplicate lines.

question = 'Is this level onset correct? Click on where you want to set it if not';

for i = 1:numel(levelOnsets)

    % Plotting the current level onset:
    vertLine = xline(levelOnsets(i), '--r', 'LineWidth', 1);

    % Zooming in from 10 s before to 20 s after the onset:
    set(ax,'XLim',[levelOnsets(i)-SR*10 levelOnsets(i)+SR*20])

    % Asking until the user accepts the position of this onset:
    answer = questdlg(question, 'Manual detection', 'Yes','No','No');
    while strcmp(answer, 'No')
        [x, ~] = ginput(1);
        % ginput returns [] if the user presses Enter without clicking:
        if ~isempty(x)
            % Onsets are used as sample indices, so they must be integers:
            levelOnsets(i) = round(x);
            delete(vertLine)
            vertLine = xline(levelOnsets(i), '--r', 'LineWidth', 1);
        end
        answer = questdlg(question, 'Manual detection', 'Yes','No','No');
    end

    % Removing the line whatever the answer, to avoid leftover lines:
    delete(vertLine)
end

%% fullManualInputs
function levelOnsets = fullManualInputs(Append, levelOnsets)
% FULLMANUALINPUTS  Ask the user to type level onsets (comma separated).
% Inputs:
% - Append: true to add the typed onsets to levelOnsets, false to replace
%   levelOnsets by the typed onsets
% - levelOnsets: (optional) existing onsets, only used when Append is true
% Outputs:
% - levelOnsets: typed onsets (Append false), or the existing onsets plus
%   the typed ones, sorted in ascending order (Append true). Empty if the
%   dialog is cancelled and nothing is appended.

if nargin < 2
    levelOnsets = [];
end

% Setting the prompt:
if Append
    prompt = {'To add new triggers, hover with the mouse over the specific points to get their x value. If you want to add several, separate them with commas'};
else
    prompt = {'The level begin detection didnt work. You will beed to select the level begins manually'};
end

dlgtitle = 'Input';
dims = [1 35];
% inputdlg needs exactly one default answer per prompt:
definput = {''};
opts.WindowStyle = 'normal';
% Asking user input:
answer = inputdlg(prompt,dlgtitle,dims,definput,opts);

% Parsing the answer: numbers separated by commas. A cancelled dialog
% returns {}, which gives an empty string and hence no numbers.
Str = sprintf('%s,', answer{:});
newOnsets = sscanf(Str, '%g,')';

if Append
    % Keep the onsets ordered so consecutive onsets delimit the blocks:
    levelOnsets = sort([levelOnsets(:)', newOnsets]);
else
    levelOnsets = newOnsets;
end


%% checkDataIntegrity
function checkDataIntegrity(filePath, OriginalData, blockOnset_ind, blockOffset_ind, tol)
% CHECKDATAINTEGRITY  Reload a saved file and compare it to the original.
% If the data differ, the saved file is deleted and an error is raised.
% Inputs:
% - filePath: path to the file that was saved
% - OriginalData: data matrix (channels x samples) used to generate the
%   subfiles
% - blockOnset_ind: index of the onset of the block in the original data
% - blockOffset_ind: index of the offset of the block in the original data
% - tol: (optional) largest allowed absolute difference per value,
%   default 0 (exact match)
% NOTE(review): EDF stores samples as 16-bit integers. If the writer
% rescales the data, an exact match may not hold; a small tol can then be
% passed. Confirm against ft_write_data's EDF behaviour.
% Errors with identifier 'ECoGPruner:dataIntegrity' on mismatch.

if nargin < 5 || isempty(tol)
    tol = 0;
end

% Loading the newly saved data:
[~, newSignal, ~] = loadData(filePath);

% Extracting only the part of the original data we are interested in:
OriginalDataBlock = OriginalData(:, blockOnset_ind:blockOffset_ind);

% Getting only the data matrix of the new signal (no headers):
newSignalRaw = newSignal.trial{1,1};

% Comparing the two matrices. A size mismatch is itself a corruption, and
% it must be checked before subtracting (subtraction would error or
% broadcast).
if ~isequal(size(OriginalDataBlock), size(newSignalRaw))
    isCorrupted = true;
else
    % Flatten with (:) so ANY differing value is caught. On a matrix,
    % any() returns one value per column, and if() only passes when all are
    % true. A NaN present in only one of the two matrices also counts as a
    % difference.
    absDiff = abs(OriginalDataBlock(:) - newSignalRaw(:));
    nanMismatch = xor(isnan(OriginalDataBlock(:)), isnan(newSignalRaw(:)));
    isCorrupted = any(absDiff > tol | nanMismatch);
end

if isCorrupted
    errorMessage = sprintf("There was an issue in the saving of the file %s! This file will be deleted! You will need to look at your signal closer to figure out what is going on!", filePath);
    delete(filePath)
    error('ECoGPruner:dataIntegrity', '%s', errorMessage)
end
    







