private/som_subject_batch.m
916fe4f9
 function decode = som_subject_batch(header,subjectdata,somOpts)
 %SOM_SUBJECT_BATCH Per-subject, per-timeslice n-fold cross-validated SOM decoding.
 %   DECODE = SOM_SUBJECT_BATCH(HEADER, SUBJECTDATA, SOMOPTS) builds a time
 %   point matrix for every subject, and for every timeslice runs an
 %   n-fold cross validation using som_train / som_decode.
 %
 %   Inputs (fields read by this function):
 %     header.timeline                       - base timeline struct
 %     header.frameShift.frameShiftStart     - first frame shift
 %     header.frameShift.frameShiftEnd       - last frame shift
 %     header.frameShift.decodeDuration      - copied into the timeline
 %     header.classDef.labelCells / conditionCells / eventMatrix
 %     subjectdata  - cell array; each element must have a .pst field
 %     somOpts.rnd        - if true, shuffle data points before splitting
 %     somOpts.nantozero  - if true, replace NaN features with 0
 %     somOpts.nfold      - number of cross-validation folds
 %     (somOpts is also passed through to som_train)
 %
 %   Output:
 %     decode.decodePerformance - one column per subject, one row per
 %                                timeslice; each entry is the mean fold
 %                                performance returned by som_decode
 %     decode.rawTimeCourse     - horizontal concatenation of every
 %                                subject's .pst (shape depends on .pst)
 
 global NODALYZE_SOMTOOLBOX;
 addpath(fullfile(getTbxPath,NODALYZE_SOMTOOLBOX));
 
 RANDOMIZE_DATAPOINTS = somOpts.rnd;
 NAN_AS_ZERO = somOpts.nantozero;
 
 decode = struct;
 decode.decodePerformance = [];
 decode.rawTimeCourse     = [];
 
 nSubjects = numel(subjectdata);
 disp(sprintf('batch processing %u subjects',nSubjects));
 
 timeline = header.timeline;
 timeline.frameShiftStart = header.frameShift.frameShiftStart;
 timeline.frameShiftEnd   = header.frameShift.frameShiftEnd;
 timeline.decodeDuration  = header.frameShift.decodeDuration;
 
 % Number of timeslices to decode (same count the original loop used).
 nTimeSlices = numel(1:timeline.frameShiftEnd-timeline.frameShiftStart+1);
 
 timePointMatrix = cell(1,nSubjects);
 for subjectDataID = 1:nSubjects
     currentSubject = subjectdata{subjectDataID};
     % NOTE(review): LabelMap is rebuilt per subject on purpose; hoisting it
     % is only safe if LabelMap has value semantics - confirm before moving.
     timePointArgs.pst           = currentSubject.pst;
     timePointArgs.labelMap      = LabelMap(header.classDef.labelCells,header.classDef.conditionCells);
     timePointArgs.eventList     = header.classDef.eventMatrix;
 
     timePointMatrix{subjectDataID} = buildTimePointMatrix(timeline,timePointArgs);
 
     decode.rawTimeCourse = [decode.rawTimeCourse currentSubject.pst];
 
     display(sprintf('%u -fold cross validation for %u timeslices.\n',somOpts.nfold,nTimeSlices));
 
     decode_timeline = zeros(nTimeSlices,1);
     for timeIndex = 1:nTimeSlices
         svmstruct = calculateSVMTables(timePointMatrix{subjectDataID},timeIndex);
         decode_timeline(timeIndex) = crossValidatedPerformance( ...
             svmstruct, somOpts, RANDOMIZE_DATAPOINTS, NAN_AS_ZERO);
     end
     decode.decodePerformance = [decode.decodePerformance decode_timeline];
 end
 display('decode done');
 end
 
 
 function performance = crossValidatedPerformance(svmstruct, somOpts, randomize, nanAsZero)
 %CROSSVALIDATEDPERFORMANCE Mean n-fold SOM decoding performance for one timeslice.
 %   Splits the rows of svmstruct.svmdata / svmstruct.svmlabel into
 %   somOpts.nfold contiguous folds. It trains on all-but-one fold, scores
 %   the held-out fold with som_decode, and returns the mean over the
 %   folds that were actually evaluated. Returns 0 when no fold could be
 %   evaluated (e.g. no data points).
 nElements = length(svmstruct.svmlabel);
 
 if randomize
     rndindex = randperm(nElements);
     svmstruct.svmdata  = svmstruct.svmdata(rndindex,:);
     svmstruct.svmlabel = svmstruct.svmlabel(rndindex);
 end
 
 nFold = somOpts.nfold;
 foldPerformance = [];   % only real fold scores; no seed value to bias the mean
 for iFold = 1:nFold
     % Integer fold boundaries: round(k*n/F) is monotone with round(0)=0 and
     % round(F*n/F)=n, so the folds partition 1..nElements exactly and their
     % sizes differ by at most one.
     chunkstart = round((iFold-1)*nElements/nFold) + 1;
     chunkend   = round(iFold*nElements/nFold);
     if chunkend < chunkstart
         continue;   % empty validation fold (fewer points than folds): nothing to score
     end
 
     svm_train_label = svmstruct.svmlabel;
     svm_train_data  = svmstruct.svmdata;
 
     svm_validation_label = svmstruct.svmlabel(chunkstart:chunkend);
     svm_validation_data  = svmstruct.svmdata(chunkstart:chunkend,:);
 
     svm_train_label(chunkstart:chunkend) = [];    % del test set
     svm_train_data(chunkstart:chunkend,:) = [];   % del test set
 
     if nanAsZero
         svm_train_data(isnan(svm_train_data))=0;
         svm_validation_data(isnan(svm_validation_data))=0;
         display('NaN to 0');
     end
 
     if isempty(svm_train_data)
         % Nothing left to train on (e.g. nfold == 1): score as 0, as before.
         foldScore = 0;
     else
         [sD, sM] = som_train(svm_train_label, svm_train_data, somOpts); %#ok<ASGLU>
         foldScore = som_decode(sM, svm_validation_data,svm_validation_label);
     end
 
     foldPerformance(end+1) = foldScore; %#ok<AGROW>
 end
 
 if isempty(foldPerformance)
     performance = 0;
 else
     performance = mean(foldPerformance);
 end
 end