Statismo  0.10.1
 All Classes Namespaces Functions Typedefs
DataManager.hxx
1 /*
2  * This file is part of the statismo library.
3  *
4  * Author: Marcel Luethi (marcel.luethi@unibas.ch)
5  *
6  * Copyright (c) 2011 University of Basel
7  * All rights reserved.
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * Redistributions of source code must retain the above copyright notice,
14  * this list of conditions and the following disclaimer.
15  *
16  * Redistributions in binary form must reproduce the above copyright
17  * notice, this list of conditions and the following disclaimer in the
18  * documentation and/or other materials provided with the distribution.
19  *
20  * Neither the name of the project's author nor the names of its
21  * contributors may be used to endorse or promote products derived from
22  * this software without specific prior written permission.
23  *
24  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
25  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
26  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
27  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
28  * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
29  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
30  * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
31  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
32  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
33  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
34  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35  *
36  */
37 
38 #ifndef __DataManager_hxx
39 #define __DataManager_hxx
40 
41 #include <iostream>
42 
43 #include "DataManager.h"
44 #include "HDF5Utils.h"
45 
46 namespace statismo {
47 
49 // Data manager
51 
52 template<typename T>
53 DataManager<T>::DataManager(const RepresenterType* representer)
54  : m_representer(representer->Clone()) {
55 }
56 
57 template<typename T>
59  for (typename DataItemListType::iterator it =
60  m_DataItemList.begin();
61  it != m_DataItemList.end(); ++it) {
62  delete (*it);
63  }
64  m_DataItemList.clear();
65  if (m_representer) {
66  m_representer->Delete();
67  }
68 
69 }
70 
71 
72 
73 template<typename T>
75 DataManager<T>::Load(Representer<T>* representer,
76  const std::string& filename) {
77  using namespace H5;
78 
79  DataManager<T>* newDataManager = 0;
80 
81  H5File file;
82  try {
83  file = H5File(filename.c_str(), H5F_ACC_RDONLY);
84  } catch (H5::Exception& e) {
85  std::string msg(
86  std::string("could not open HDF5 file \n") + e.getCDetailMsg());
87  throw StatisticalModelException(msg.c_str());
88  }
89 
90  try {
91  // loading representer
92 
93  Group representerGroup = file.openGroup("./representer");
94  std::string rep_name = HDF5Utils::readStringAttribute(representerGroup, "name");
95  std::string repTypeStr = HDF5Utils::readStringAttribute(representerGroup, "datasetType");
96  std::string versionStr = HDF5Utils::readStringAttribute(representerGroup, "version");
97  typename RepresenterType::RepresenterDataType type = RepresenterType::TypeFromString(repTypeStr);
98  if (type == RepresenterType::CUSTOM || type == RepresenterType::UNKNOWN) {
99  if (rep_name != representer->GetName()) {
100  std::ostringstream os;
101  os << "A different representer was used to create the file and the representer is not of a standard type ";
102  os << ("(RepresenterName = ") << rep_name << " does not match required name = " << representer->GetName() << ")";
103  os << "Cannot load hdf5 file";
104  throw StatisticalModelException(os.str().c_str());
105  }
106  if (versionStr != representer->GetVersion()) {
107  std::ostringstream os;
108  os << "The version of the representers do not match ";
109  os << ("(Version = ") << versionStr << " != = " << representer->GetVersion() << ")";
110  os << "Cannot load hdf5 file";
111 
112  }
113 
114  }
115  if (type != representer->GetType()) {
116  std::ostringstream os;
117  os << "The representer that was provided cannot be used to load the dataset ";
118  os << "(" << type << " != " << representer->GetType() << ").";
119  os << "Cannot load hdf5 file.";
120  throw StatisticalModelException(os.str().c_str());
121  }
122 
123  representer->Load(representerGroup);
124  representerGroup.close();
125  newDataManager = new DataManager<T>(representer);
126 
127 
128  Group publicGroup = file.openGroup("/data");
129  unsigned numds = HDF5Utils::readInt(publicGroup, "./NumberOfDatasets");
130 
131  for (unsigned num = 0; num < numds; num++) {
132  std::ostringstream ss;
133  ss << "./dataset-" << num;
134 
135  Group dsGroup = file.openGroup(ss.str().c_str());
136  newDataManager->m_DataItemList.push_back(
137  DataItemType::Load(representer, dsGroup));
138 
139  }
140 
141  } catch (H5::Exception& e) {
142  std::string msg(
143  std::string(
144  "an exception occurred while reading data matrix to HDF5 file \n")
145  + e.getCDetailMsg());
146  throw StatisticalModelException(msg.c_str());
147  }
148 
149  file.close();
150 
151  assert(newDataManager != 0);
152  return newDataManager;
153 }
154 
155 template<typename T>
156 void DataManager<T>::Save(const std::string& filename) const {
157  using namespace H5;
158 
159  assert(m_representer != 0);
160 
161  H5File file;
162 
163  try {
164  file = H5File(filename.c_str(), H5F_ACC_TRUNC);
165  } catch (H5::Exception& e) {
166  std::string msg(
167  std::string("Could not open HDF5 file for writing \n")
168  + e.getCDetailMsg());
169  throw StatisticalModelException(msg.c_str());
170  }
171 
172  try {
173 
174  Group representerGroup = file.createGroup("./representer");
175  std::string dataTypeStr = RepresenterType::TypeToString(m_representer->GetType());
176 
177  HDF5Utils::writeStringAttribute(representerGroup, "name", m_representer->GetName());
178  HDF5Utils::writeStringAttribute(representerGroup, "version", m_representer->GetVersion());
179  HDF5Utils::writeStringAttribute(representerGroup, "datasetType", dataTypeStr);
180 
181  this->m_representer->Save(representerGroup);
182  representerGroup.close();
183 
184 
185  Group publicGroup = file.createGroup("./data");
186  HDF5Utils::writeInt(publicGroup, "./NumberOfDatasets",
187  this->m_DataItemList.size());
188 
189  unsigned num = 0;
190  for (typename DataItemListType::const_iterator it =
191  this->m_DataItemList.begin();
192  it != this->m_DataItemList.end(); ++it) {
193  std::ostringstream ss;
194  ss << "./dataset-" << num;
195 
196  Group dsGroup = file.createGroup(ss.str().c_str());
197 
198  (*it)->Save(dsGroup);
199 
200  dsGroup.close();
201  num++;
202  }
203  } catch (H5::Exception& e) {
204  std::string msg(
205  std::string(
206  "an exception occurred while writing data matrix to HDF5 file \n")
207  + e.getCDetailMsg());
208  throw StatisticalModelException(msg.c_str());
209  }
210  file.close();
211 }
212 
213 template<typename T>
214 void DataManager<T>::AddDataset(DatasetConstPointerType dataset,
215  const std::string& URI) {
216 
217  DatasetPointerType sample;
218  sample = m_representer->CloneDataset(dataset);
219 
220  m_DataItemList.push_back(
221  DataItemType::Create(m_representer, URI,
222  m_representer->SampleToSampleVector(sample)));
223  m_representer->DeleteDataset(sample);
224 }
225 
226 template<typename T>
227 typename DataManager<T>::DataItemListType DataManager<T>::GetData() const {
228  return m_DataItemList;
229 }
230 
231 template<typename T>
232 typename DataManager<T>::CrossValidationFoldListType DataManager<T>::GetCrossValidationFolds(
233  unsigned nFolds, bool randomize) const {
234  if (nFolds <= 1 || nFolds > GetNumberOfSamples()) {
236  "Invalid number of folds specified in GetCrossValidationFolds");
237  }
238  unsigned nElemsPerFold = GetNumberOfSamples() / nFolds;
239 
240  // we create a vector with as many entries as datasets. Each entry contains the
241  // fold the entry belongs to
242  std::vector<unsigned> batchAssignment(GetNumberOfSamples());
243 
244  for (unsigned i = 0; i < GetNumberOfSamples(); i++) {
245  batchAssignment[i] = std::min(i / nElemsPerFold, nFolds);
246  }
247 
248  // randomly shuffle the vector
249  srand(time(0));
250  if (randomize) {
251  std::random_shuffle(batchAssignment.begin(), batchAssignment.end());
252  }
253 
254  // now we create the folds
255  CrossValidationFoldListType foldList;
256  for (unsigned currentFold = 0; currentFold < nFolds; currentFold++) {
257  DataItemListType trainingData;
258  DataItemListType testingData;
259 
260  unsigned sampleNum = 0;
261  for (typename DataItemListType::const_iterator it =
262  m_DataItemList.begin();
263  it != m_DataItemList.end(); ++it) {
264  if (batchAssignment[sampleNum] != currentFold) {
265  trainingData.push_back(*it);
266  } else {
267  testingData.push_back(*it);
268  }
269  ++sampleNum;
270  }
271  CrossValidationFoldType fold(trainingData, testingData);
272  foldList.push_back(fold);
273  }
274  return foldList;
275 }
276 
277 template<typename T>
278 typename DataManager<T>::CrossValidationFoldListType DataManager<T>::GetLeaveOneOutCrossValidationFolds() const {
279  CrossValidationFoldListType foldList;
280  for (unsigned currentFold = 0; currentFold < GetNumberOfSamples();
281  currentFold++) {
282  DataItemListType trainingData;
283  DataItemListType testingData;
284 
285  unsigned sampleNum = 0;
286  for (typename DataItemListType::const_iterator it =
287  m_DataItemList.begin();
288  it != m_DataItemList.end(); ++it, ++sampleNum) {
289  if (sampleNum == currentFold) {
290  testingData.push_back(*it);
291  } else {
292  trainingData.push_back(*it);
293  }
294  }
295  CrossValidationFoldType fold(trainingData, testingData);
296  foldList.push_back(fold);
297  }
298  return foldList;
299 }
300 
301 } // Namespace statismo
302 
303 #endif
Manages Training and Test Data for building Statistical Models and provides functionality for Crossva...
Definition: DataManager.h:114
A trivial representer, that does no representation at all, but works directly with vectorial data...
Definition: TrivialVectorialRepresenter.h:83
Generic Exception class for the statismo Library.
Definition: Exceptions.h:68
Holds training and test data used for Crossvalidation.
Definition: DataManager.h:58
void Delete()
Definition: DataManager.h:148