00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00024
00025 #ifndef QVTENSOR_H
00026 #define QVTENSOR_H
00027
00028 #include <cblas.h>
00029 #include <iostream>
00030
00031 #include <QSharedData>
00032 #include <qtensor/qtensorindexator.h>
00033
// Generic minimum / maximum helper macros.
// NOTE: each argument may be evaluated twice, so never pass expressions with
// side effects (e.g. MIN(i++, j)).
// Guarded so that an earlier definition (for example the one in glibc's
// <sys/param.h>) is not redefined with a differently spelled replacement
// list, which compilers diagnose as an invalid macro redefinition.
#ifndef MIN
#define MIN(X,Y) (((X)<(Y))?(X):(Y))
#endif
#ifndef MAX
#define MAX(X,Y) (((X)>(Y))?(X):(Y))
#endif
00036
00163 class QTensor
00164 {
00165 public:
00168 QTensor(const QTensor &tensor): dataSize(tensor.dataSize), dims(tensor.dims), indexIds(tensor.indexIds), data(tensor.data) { }
00169
00176 QTensor(const QTensorValence &indexList):
00177 dataSize(1), dims(indexList.size()), indexIds(indexList.size())
00178 {
00179 for (int n = 0; n < indexList.size(); n++)
00180 {
00181 indexIds[n] = indexList.at(n).id;
00182 dataSize *= (dims[n] = indexList.at(n).dim);
00183 }
00184 data = new QTensorSharedData(dataSize);
00185 }
00186
00191 QTensor &operator=(const QTensor &tensor) { return copy(tensor); };
00192
00195 bool operator==(const QTensor &tensor) const { return equals(tensor); };
00196
00199 QTensor operator*(const QTensor &tensor) const { return tensorProduct(tensor); };
00200
00203 QTensor operator^(const QTensor &tensor) const { return innerProduct(tensor); };
00204
00209 QTensor operator()(const QTensorValence &indexList) const { return renameIndexes(indexList); };
00210
00215 const int getDataSize() const { return dataSize; }
00216
00219 const double *getReadData() const { return data->getReadData(); }
00220
00223 double *getWriteData() { return data->getWriteData(); }
00224
00227 QTensorValence getValence() const;
00228
00235 QTensor slice(const QTensorIndexValues &indexRangeList) const;
00236
00242 QTensor transpose(const QTensorValence &indexList) const;
00243
00248 QTensor transpose(const QTensorIndex &i, const QTensorIndex &j) const;
00249
00254 QTensor outerProduct(const QTensor &tensor) const;
00255
00259 QTensor tensorProduct(const QTensor &tensor) const;
00260
00264 QTensor innerProduct(const QTensor &tensor) const;
00265
00269 bool equals(const QTensor &tensor) const;
00270 QTensor renameIndexes(const QTensorValence &indexList) const;
00271
00277 QTensor ©(const QTensor &tensor);
00278
00279 private:
00280 friend std::ostream& operator << ( std::ostream &os, const QTensor &tensor );
00281 int dataSize;
00282 QVector <int> dims;
00283 QVector <int> indexIds;
00284
00285
00286 QTensor transpose(const QVector<int> &sorting) const;
00287 QTensor transpose(const int index1Position, const int index2Position) const;
00288 QTensor contract() const;
00289
00290
00291 class QTensorSharedData: public QSharedData
00292 {
00293 public:
00294 QTensorSharedData(const int size): QSharedData(), dataSize(size), data(new double[dataSize])
00295 { }
00296
00297 QTensorSharedData(const QTensorSharedData &tensorData): QSharedData(), dataSize(tensorData.dataSize), data(new double[100*dataSize])
00298 { cblas_dcopy(dataSize, tensorData.getReadData(), 1, getWriteData(), 1); }
00299
00300 ~QTensorSharedData()
00301 { delete data; }
00302
00303 const double *getReadData() const { return data; }
00304 double *getWriteData() { return data; }
00305
00306 private:
00307 int dataSize;
00308 double *data;
00309 };
00310
00311 QSharedDataPointer< QTensorSharedData > data;
00312 };
00313
00314 std::ostream& operator << ( std::ostream &os, const QTensor &tensor );
00315
00316 #endif