38 #ifndef __DataManager_hxx
39 #define __DataManager_hxx
43 #include "DataManager.h"
44 #include "HDF5Utils.h"
54 : m_representer(representer->Clone()) {
59 for (
typename DataItemListType::iterator it =
60 m_DataItemList.begin();
61 it != m_DataItemList.end(); ++it) {
64 m_DataItemList.clear();
76 const std::string& filename) {
83 file = H5File(filename.c_str(), H5F_ACC_RDONLY);
84 }
catch (H5::Exception& e) {
86 std::string(
"could not open HDF5 file \n") + e.getCDetailMsg());
93 Group representerGroup = file.openGroup(
"./representer");
94 std::string rep_name = HDF5Utils::readStringAttribute(representerGroup,
"name");
95 std::string repTypeStr = HDF5Utils::readStringAttribute(representerGroup,
"datasetType");
96 std::string versionStr = HDF5Utils::readStringAttribute(representerGroup,
"version");
97 typename RepresenterType::RepresenterDataType type = RepresenterType::TypeFromString(repTypeStr);
98 if (type == RepresenterType::CUSTOM || type == RepresenterType::UNKNOWN) {
99 if (rep_name != representer->GetName()) {
100 std::ostringstream os;
101 os <<
"A different representer was used to create the file and the representer is not of a standard type ";
102 os << (
"(RepresenterName = ") << rep_name <<
" does not match required name = " << representer->GetName() <<
")";
103 os <<
"Cannot load hdf5 file";
106 if (versionStr != representer->GetVersion()) {
107 std::ostringstream os;
108 os <<
"The version of the representers do not match ";
109 os << (
"(Version = ") << versionStr <<
" != = " << representer->GetVersion() <<
")";
110 os <<
"Cannot load hdf5 file";
115 if (type != representer->GetType()) {
116 std::ostringstream os;
117 os <<
"The representer that was provided cannot be used to load the dataset ";
118 os <<
"(" << type <<
" != " << representer->GetType() <<
").";
119 os <<
"Cannot load hdf5 file.";
123 representer->Load(representerGroup);
124 representerGroup.close();
128 Group publicGroup = file.openGroup(
"/data");
129 unsigned numds = HDF5Utils::readInt(publicGroup,
"./NumberOfDatasets");
131 for (
unsigned num = 0; num < numds; num++) {
132 std::ostringstream ss;
133 ss <<
"./dataset-" << num;
135 Group dsGroup = file.openGroup(ss.str().c_str());
136 newDataManager->m_DataItemList.push_back(
137 DataItemType::Load(representer, dsGroup));
141 }
catch (H5::Exception& e) {
144 "an exception occurred while reading data matrix to HDF5 file \n")
145 + e.getCDetailMsg());
151 assert(newDataManager != 0);
152 return newDataManager;
159 assert(m_representer != 0);
164 file = H5File(filename.c_str(), H5F_ACC_TRUNC);
165 }
catch (H5::Exception& e) {
167 std::string(
"Could not open HDF5 file for writing \n")
168 + e.getCDetailMsg());
174 Group representerGroup = file.createGroup(
"./representer");
175 std::string dataTypeStr = RepresenterType::TypeToString(m_representer->GetType());
177 HDF5Utils::writeStringAttribute(representerGroup,
"name", m_representer->GetName());
178 HDF5Utils::writeStringAttribute(representerGroup,
"version", m_representer->GetVersion());
179 HDF5Utils::writeStringAttribute(representerGroup,
"datasetType", dataTypeStr);
181 this->m_representer->Save(representerGroup);
182 representerGroup.close();
185 Group publicGroup = file.createGroup(
"./data");
186 HDF5Utils::writeInt(publicGroup,
"./NumberOfDatasets",
187 this->m_DataItemList.size());
190 for (
typename DataItemListType::const_iterator it =
191 this->m_DataItemList.begin();
192 it != this->m_DataItemList.end(); ++it) {
193 std::ostringstream ss;
194 ss <<
"./dataset-" << num;
196 Group dsGroup = file.createGroup(ss.str().c_str());
198 (*it)->Save(dsGroup);
203 }
catch (H5::Exception& e) {
206 "an exception occurred while writing data matrix to HDF5 file \n")
207 + e.getCDetailMsg());
215 const std::string& URI) {
217 DatasetPointerType sample;
218 sample = m_representer->CloneDataset(dataset);
220 m_DataItemList.push_back(
221 DataItemType::Create(m_representer, URI,
222 m_representer->SampleToSampleVector(sample)));
223 m_representer->DeleteDataset(sample);
228 return m_DataItemList;
233 unsigned nFolds,
bool randomize)
const {
234 if (nFolds <= 1 || nFolds > GetNumberOfSamples()) {
236 "Invalid number of folds specified in GetCrossValidationFolds");
238 unsigned nElemsPerFold = GetNumberOfSamples() / nFolds;
242 std::vector<unsigned> batchAssignment(GetNumberOfSamples());
244 for (
unsigned i = 0; i < GetNumberOfSamples(); i++) {
245 batchAssignment[i] = std::min(i / nElemsPerFold, nFolds);
251 std::random_shuffle(batchAssignment.begin(), batchAssignment.end());
255 CrossValidationFoldListType foldList;
256 for (
unsigned currentFold = 0; currentFold < nFolds; currentFold++) {
257 DataItemListType trainingData;
258 DataItemListType testingData;
260 unsigned sampleNum = 0;
261 for (
typename DataItemListType::const_iterator it =
262 m_DataItemList.begin();
263 it != m_DataItemList.end(); ++it) {
264 if (batchAssignment[sampleNum] != currentFold) {
265 trainingData.push_back(*it);
267 testingData.push_back(*it);
272 foldList.push_back(fold);
279 CrossValidationFoldListType foldList;
280 for (
unsigned currentFold = 0; currentFold < GetNumberOfSamples();
282 DataItemListType trainingData;
283 DataItemListType testingData;
285 unsigned sampleNum = 0;
286 for (
typename DataItemListType::const_iterator it =
287 m_DataItemList.begin();
288 it != m_DataItemList.end(); ++it, ++sampleNum) {
289 if (sampleNum == currentFold) {
290 testingData.push_back(*it);
292 trainingData.push_back(*it);
296 foldList.push_back(fold);
Manages Training and Test Data for building Statistical Models and provides functionality for Crossvalidation.
Definition: DataManager.h:114
A trivial representer, that does no representation at all, but works directly with vectorial data...
Definition: TrivialVectorialRepresenter.h:83
Generic Exception class for the statismo Library.
Definition: Exceptions.h:68
Holds training and test data used for Crossvalidation.
Definition: DataManager.h:58
void Delete()
Definition: DataManager.h:148