diff --git a/DirectXTex/DirectXTex.h b/DirectXTex/DirectXTex.h index 1e6c5ef..64078d3 100644 --- a/DirectXTex/DirectXTex.h +++ b/DirectXTex/DirectXTex.h @@ -396,7 +396,11 @@ namespace DirectX void *__cdecl GetBufferPointer() const noexcept { return m_buffer; } size_t __cdecl GetBufferSize() const noexcept { return m_size; } + HRESULT __cdecl Resize(size_t size) noexcept; + // Reallocate for a new size + HRESULT __cdecl Trim(size_t size) noexcept; + // Shorten size without reallocation private: void* m_buffer; diff --git a/DirectXTex/DirectXTexUtil.cpp b/DirectXTex/DirectXTexUtil.cpp index d7df359..68eda92 100644 --- a/DirectXTex/DirectXTexUtil.cpp +++ b/DirectXTex/DirectXTexUtil.cpp @@ -1510,3 +1510,25 @@ HRESULT Blob::Trim(size_t size) noexcept return S_OK; } + +HRESULT Blob::Resize(size_t size) noexcept +{ + if (!size) + return E_INVALIDARG; + + if (!m_buffer || !m_size) + return E_UNEXPECTED; + + void *tbuffer = _aligned_malloc(size, 16); + if (!tbuffer) + return E_OUTOFMEMORY; + + memcpy(tbuffer, m_buffer, std::min(m_size, size)); + + Release(); + + m_buffer = tbuffer; + m_size = size; + + return S_OK; +} diff --git a/DirectXTex/DirectXTexWIC.cpp b/DirectXTex/DirectXTexWIC.cpp index b77a9cc..94f92f3 100644 --- a/DirectXTex/DirectXTexWIC.cpp +++ b/DirectXTex/DirectXTexWIC.cpp @@ -11,89 +11,14 @@ #include "DirectXTexP.h" -//------------------------------------------------------------------------------------- -// IStream support for WIC Memory routines -//------------------------------------------------------------------------------------- - -#if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) && (WINAPI_FAMILY != WINAPI_FAMILY_PHONE_APP) - - #include - #pragma comment(lib,"shcore.lib") - -#ifdef __cplusplus_winrt - - static inline HRESULT CreateMemoryStream(_Outptr_ IStream** stream) - { - auto randomAccessStream = ref new ::Windows::Storage::Streams::InMemoryRandomAccessStream(); - return CreateStreamOverRandomAccessStream(randomAccessStream, IID_PPV_ARGS(stream)); - } - -#else - -#pragma warning(push) -#pragma warning(disable : 4619 5038) - #include - #include -#pragma warning(pop) - -#pragma warning(push) -#pragma warning(disable : 4471 5204) - #include -#pragma warning(pop) - - static inline HRESULT CreateMemoryStream(_Outptr_ IStream** stream) - { - Microsoft::WRL::ComPtr abiStream; - HRESULT hr = Windows::Foundation::ActivateInstance( - Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(), - abiStream.GetAddressOf()); - - if (SUCCEEDED(hr)) - { - hr = CreateStreamOverRandomAccessStream(abiStream.Get(), IID_PPV_ARGS(stream)); - } - return hr; - } - -#endif // __cplusplus_winrt - -#elif (!defined(WINAPI_FAMILY) || (WINAPI_FAMILY == WINAPI_FAMILY_DESKTOP_APP)) && (_WIN32_WINNT >= _WIN32_WINNT_WIN8) - -#include -#pragma comment(lib,"shlwapi.lib") - - static inline HRESULT CreateMemoryStream(_Outptr_ IStream** stream) noexcept - { - if (!stream) - return E_INVALIDARG; - - *stream = SHCreateMemStream(nullptr, 0u); - if (!*stream) - return E_OUTOFMEMORY; - - return S_OK; - } - -#else - - #pragma prefast(suppress:6387 28196, "a simple wrapper around an existing annotated function" ); - static inline HRESULT CreateMemoryStream(_Outptr_ IStream** stream) noexcept - { - return CreateStreamOnHGlobal(nullptr, TRUE, stream); - } - -#endif - using namespace DirectX; using Microsoft::WRL::ComPtr; namespace { - //------------------------------------------------------------------------------------- // WIC Pixel Format nearest conversion table //------------------------------------------------------------------------------------- - struct WICConvert { const GUID& source; @@ -266,6 +191,331 @@ namespace } + //------------------------------------------------------------------------------------- + // IStream over a Blob for WIC in-memory write functions + //------------------------------------------------------------------------------------- + class MemoryStreamOnBlob : public IStream + { + MemoryStreamOnBlob(Blob& blob) noexcept : + mBlob(blob), + m_streamPosition(0), + m_streamEOF(0), + mRefCount(1) + { + assert(mBlob.GetBufferPointer() && mBlob.GetBufferSize() > 0); + } + + public: + virtual ~MemoryStreamOnBlob() = default; + + MemoryStreamOnBlob(MemoryStreamOnBlob&&) = delete; + MemoryStreamOnBlob& operator= (MemoryStreamOnBlob&&) = delete; + + MemoryStreamOnBlob(MemoryStreamOnBlob const&) = delete; + MemoryStreamOnBlob& operator= (MemoryStreamOnBlob const&) = delete; + + // IUnknown + HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void** ppvObject) override + { + if (iid == __uuidof(IUnknown) + || iid == __uuidof(IStream) + || iid == __uuidof(ISequentialStream)) + { + *ppvObject = static_cast(this); + AddRef(); + return S_OK; + } + else + return E_NOINTERFACE; + } + + ULONG STDMETHODCALLTYPE AddRef() override + { + return InterlockedIncrement(&mRefCount); + } + + ULONG STDMETHODCALLTYPE Release() override + { + ULONG res = InterlockedDecrement(&mRefCount); + if (res == 0) + { + delete this; + } + return res; + } + + // ISequentialStream + HRESULT STDMETHODCALLTYPE Read(void* pv, ULONG cb, ULONG* pcbRead) override + { + size_t maxRead = m_streamEOF - m_streamPosition; + auto ptr = static_cast(mBlob.GetBufferPointer()); + if (cb > maxRead) + { + uint64_t pos = uint64_t(m_streamPosition) + uint64_t(maxRead); + if (pos > UINT32_MAX) + return HRESULT_E_ARITHMETIC_OVERFLOW; + + memcpy(pv, &ptr[m_streamPosition], maxRead); + + m_streamPosition = static_cast(pos); + + if (pcbRead) + { + *pcbRead = static_cast(maxRead); + } + return E_BOUNDS; + } + else + { + uint64_t pos = uint64_t(m_streamPosition) + uint64_t(cb); + if (pos > UINT32_MAX) + return HRESULT_E_ARITHMETIC_OVERFLOW; + + memcpy(pv, &ptr[m_streamPosition], cb); + + m_streamPosition = static_cast(pos); + + if (pcbRead) + { + *pcbRead = cb; + } + return S_OK; + } + } + + HRESULT STDMETHODCALLTYPE Write(void const* pv, ULONG cb, ULONG* pcbWritten) override + { + size_t blobSize = mBlob.GetBufferSize(); + size_t spaceAvailable = blobSize - m_streamPosition; + size_t growAmount = cb; + + if (spaceAvailable > 0) + { + if (spaceAvailable >= growAmount) + { + growAmount = 0; + } + else + { + growAmount -= spaceAvailable; + } + } + + if (growAmount > 0) + { + uint64_t newSize = uint64_t(blobSize); + uint64_t targetSize = uint64_t(blobSize) + growAmount; + HRESULT hr = ComputeGrowSize(newSize, targetSize); + if (FAILED(hr)) + return hr; + + hr = mBlob.Resize(static_cast(newSize)); + if (FAILED(hr)) + return hr; + } + + uint64_t pos = uint64_t(m_streamPosition) + uint64_t(cb); + if (pos > UINT32_MAX) + return HRESULT_E_ARITHMETIC_OVERFLOW; + + auto ptr = static_cast(mBlob.GetBufferPointer()); + memcpy(&ptr[m_streamPosition], pv, cb); + + m_streamPosition = static_cast(pos); + m_streamEOF = std::max(m_streamEOF, m_streamPosition); + + if (pcbWritten) + { + *pcbWritten = cb; + } + return S_OK; + } + + // IStream + HRESULT STDMETHODCALLTYPE SetSize(ULARGE_INTEGER size) override + { + if (size.HighPart > 0) + return E_OUTOFMEMORY; + + size_t blobSize = mBlob.GetBufferSize(); + + if (blobSize >= size.LowPart) + { + auto ptr = static_cast(mBlob.GetBufferPointer()); + if (m_streamEOF < size.LowPart) + { + memset(&ptr[m_streamEOF], 0, size.LowPart - m_streamEOF); + } + + m_streamEOF = static_cast(size.LowPart); + } + else + { + uint64_t newSize = uint64_t(blobSize); + uint64_t targetSize = uint64_t(size.QuadPart); + HRESULT hr = ComputeGrowSize(newSize, targetSize); + if (FAILED(hr)) + return hr; + + hr = mBlob.Resize(static_cast(newSize)); + if (FAILED(hr)) + return hr; + + blobSize = mBlob.GetBufferSize(); + + auto ptr = static_cast(mBlob.GetBufferPointer()); + if (m_streamEOF < size.LowPart) + { + memset(&ptr[m_streamEOF], 0, size.LowPart - m_streamEOF); + } + + m_streamEOF = static_cast(size.LowPart); + } + + if (m_streamPosition > m_streamEOF) + { + m_streamPosition = m_streamEOF; + } + + return S_OK; + } + + HRESULT STDMETHODCALLTYPE CopyTo(IStream*, ULARGE_INTEGER, ULARGE_INTEGER*, ULARGE_INTEGER*) override + { + return E_NOTIMPL; + } + + HRESULT STDMETHODCALLTYPE Commit(DWORD) override + { + return E_NOTIMPL; + } + + HRESULT STDMETHODCALLTYPE Revert(void) override + { + return E_NOTIMPL; + } + + HRESULT STDMETHODCALLTYPE LockRegion(ULARGE_INTEGER, ULARGE_INTEGER, DWORD) override + { + return E_NOTIMPL; + } + + HRESULT STDMETHODCALLTYPE UnlockRegion(ULARGE_INTEGER, ULARGE_INTEGER, DWORD) override + { + return E_NOTIMPL; + } + + HRESULT STDMETHODCALLTYPE Clone(IStream**) override + { + return E_NOTIMPL; + } + + HRESULT STDMETHODCALLTYPE Seek(LARGE_INTEGER liDistanceToMove, DWORD dwOrigin, ULARGE_INTEGER* lpNewFilePointer) override + { + LONGLONG newPosition = 0; + + switch (dwOrigin) + { + case STREAM_SEEK_SET: + newPosition = liDistanceToMove.QuadPart; + break; + + case STREAM_SEEK_CUR: + newPosition = static_cast(m_streamPosition) + liDistanceToMove.QuadPart; + break; + + case STREAM_SEEK_END: + newPosition = static_cast(m_streamEOF) + liDistanceToMove.QuadPart; + break; + + default: + return STG_E_INVALIDFUNCTION; + } + + HRESULT result = S_OK; + + if (newPosition > static_cast(m_streamEOF)) + { + m_streamPosition = m_streamEOF; + result = E_BOUNDS; + } + else if (newPosition < 0) + { + m_streamPosition = 0; + result = E_BOUNDS; + } + else + { + m_streamPosition = static_cast(newPosition); + } + + if (lpNewFilePointer) + { + lpNewFilePointer->QuadPart = static_cast(m_streamPosition); + } + + return result; + } + + HRESULT STDMETHODCALLTYPE Stat(STATSTG* pStatstg, DWORD) override + { + if (!pStatstg) + return E_INVALIDARG; + pStatstg->cbSize.QuadPart = static_cast(m_streamEOF); + return S_OK; + } + + HRESULT Finialize() noexcept + { + if (mRefCount > 1) + return E_FAIL; + + return mBlob.Trim(m_streamEOF); + } + + static HRESULT CreateMemoryStream(_Outptr_ MemoryStreamOnBlob** stream, Blob& blob) noexcept + { + if (!stream) + return E_INVALIDARG; + + *stream = nullptr; + + auto ptr = new (std::nothrow) MemoryStreamOnBlob(blob); + if (!ptr) + return E_OUTOFMEMORY; + + *stream = ptr; + + return S_OK; + } + + private: + Blob& mBlob; + size_t m_streamPosition; + size_t m_streamEOF; + ULONG mRefCount; + + static HRESULT ComputeGrowSize(uint64_t& newSize, uint64_t& targetSize) noexcept + { + // We grow by doubling until we hit 256MB, then we add 16MB at a time. + while (newSize < targetSize) + { + if (newSize < (256 * 1024 * 1024)) + { + newSize <<= 1; + } + else + { + newSize += 16 * 1024 * 1024; + } + if (newSize > UINT32_MAX) + return E_OUTOFMEMORY; + } + + return S_OK; + } + }; + //------------------------------------------------------------------------------------- // Determines metadata for image //------------------------------------------------------------------------------------- @@ -1198,42 +1448,33 @@ HRESULT DirectX::SaveToWICMemory( if (!image.pixels) return E_POINTER; - blob.Release(); - - ComPtr stream; - HRESULT hr = CreateMemoryStream(stream.GetAddressOf()); + HRESULT hr = blob.Initialize(65535u); if (FAILED(hr)) return hr; + ComPtr stream; + hr = MemoryStreamOnBlob::CreateMemoryStream(&stream, blob); + if (FAILED(hr)) + { + blob.Release(); + return hr; + } + hr = EncodeSingleFrame(image, flags, containerFormat, stream.Get(), targetFormat, setCustomProps); if (FAILED(hr)) + { + blob.Release(); return hr; + } - // Copy stream data into blob - STATSTG stat; - hr = stream->Stat(&stat, STATFLAG_NONAME); + hr = stream->Finialize(); if (FAILED(hr)) + { + blob.Release(); return hr; + } - if (stat.cbSize.HighPart > 0) - return HRESULT_E_FILE_TOO_LARGE; - - hr = blob.Initialize(stat.cbSize.LowPart); - if (FAILED(hr)) - return hr; - - LARGE_INTEGER li = {}; - hr = stream->Seek(li, STREAM_SEEK_SET, nullptr); - if (FAILED(hr)) - return hr; - - DWORD bytesRead; - hr = stream->Read(blob.GetBufferPointer(), static_cast(blob.GetBufferSize()), &bytesRead); - if (FAILED(hr)) - return hr; - - if (bytesRead != blob.GetBufferSize()) - return E_FAIL; + stream.Reset(); return S_OK; } @@ -1251,46 +1492,37 @@ HRESULT DirectX::SaveToWICMemory( if (!images || nimages == 0) return E_INVALIDARG; - blob.Release(); - - ComPtr stream; - HRESULT hr = CreateMemoryStream(stream.GetAddressOf()); + HRESULT hr = blob.Initialize(65535u); if (FAILED(hr)) return hr; + ComPtr stream; + hr = MemoryStreamOnBlob::CreateMemoryStream(&stream, blob); + if (FAILED(hr)) + { + blob.Release(); + return hr; + } + if (nimages > 1) hr = EncodeMultiframe(images, nimages, flags, containerFormat, stream.Get(), targetFormat, setCustomProps); else hr = EncodeSingleFrame(images[0], flags, containerFormat, stream.Get(), targetFormat, setCustomProps); if (FAILED(hr)) + { + blob.Release(); return hr; + } - // Copy stream data into blob - STATSTG stat; - hr = stream->Stat(&stat, STATFLAG_NONAME); + hr = stream->Finialize(); if (FAILED(hr)) + { + blob.Release(); return hr; + } - if (stat.cbSize.HighPart > 0) - return HRESULT_E_FILE_TOO_LARGE; - - hr = blob.Initialize(stat.cbSize.LowPart); - if (FAILED(hr)) - return hr; - - LARGE_INTEGER li = {}; - hr = stream->Seek(li, STREAM_SEEK_SET, nullptr); - if (FAILED(hr)) - return hr; - - DWORD bytesRead; - hr = stream->Read(blob.GetBufferPointer(), static_cast(blob.GetBufferSize()), &bytesRead); - if (FAILED(hr)) - return hr; - - if (bytesRead != blob.GetBufferSize()) - return E_FAIL; + stream.Reset(); return S_OK; }