winsockを使ってみる。(1)

boost::asioを使いたかったのですが、ネットワークの知識が乏しいので、まずはwinsockの勉強をしようと思います。

参考サイト

とりあえず以下のように使えるSocketクラス、SocketStreamクラスを作ってみます。
開発環境はwindows7VC++2008EEです。XP、Vistaでも問題無いはず。

Server
  • 接続してきたクライアントに文字列を送り返す。
  • SocketとStreamはsmart_ptrで。
#include <net/socket.h>

int main()
{
    try
    {
        using namespace net::tcp;
        //winsock初期化
        net::startup();
        //指定ポートでサーバ用TCPソケットを作成
        ServerSocketPtr socket = IServerSocket::create( 12345 );
        //クライアントの接続待ち
        SocketPtr client = socket->accept();
        //ソケットからストリームを取得
        SocketStreamWeakPtr stream = client->get_stream();
        //適当な文字列を送る
        stream->write_line( "hoge" );
    }
    catch( net::socket_error& e )
    {
        std::cout << e.what() << std::endl;
    }

    //winsock破棄
    net::cleanup();

    return 0;
}
Client
  • サーバに接続して文字列を受信
  • SocketとStreamはsmart_ptrで。
#include <net/socket.h>

int main()
{
    try
    {
        using namespace net::tcp;
        //winsock初期化
        net::startup();
        //クライアント用TCPソケットを作成
        ClientSocketPtr socket = IClientSocket::create();
        //指定ポート、IPアドレスのサーバに接続
        socket->connect( 12345, "localhost" );
        //ソケットからストリームを取得
        SocketStreamWeakPtr stream = socket->get_stream();
        //1行分の文字列を受信
        std::string str;
        stream->read_line( str );
        std::cout << str;
    }
    catch( net::socket_error& e )
    {
        std::cout << e.what() << std::endl;
    }

    //winsock破棄
    net::cleanup();
    return 0;
}

プロジェクトの構成

winsockの初期化と破棄

  • socket_error
    • WSAGetLastErrorの戻り値を文字列に変換
  • startup
    • winsockの初期化
  • cleanup
    • winsockの破棄
socket.h
#include <winsock2.h>
#pragma comment( lib, "wsock32.lib" )

#include <string>
#include <sstream>
#include <iostream>
#include <exception>

namespace net
{
    /**
     *  ソケット例外クラス
     */
    class socket_error
        : public std::exception
    {
    public:
        socket_error()
        {
            //エラーコードを文字列化
            void* msg;
            FormatMessage(
                FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
                0, WSAGetLastError(), MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ),
                ( char* )&msg, 0, 0
                );
            str = static_cast< const char* >( msg );
            LocalFree( msg );
        }
        socket_error( const std::string& str )
            : str( str )
        {
        }
    public:
        const char * __CLR_OR_THIS_CALL what() const
        {
            return str.c_str();
        }

    private:
        std::string str;
    };
}

namespace net
{
    /**
     *  winsockの初期化
     */
    void startup();
    /**
     *  winsockの破棄
     */
    void cleanup();
}
socket.cpp
namespace net
{
    /**
     *  WinSock初期化
     */
    void startup()
    {
        WORD version = MAKEWORD( 2, 0 );
        WSADATA data = { 0 };
        int err = WSAStartup( version, &data );
        //初期化失敗
        if( err != 0 )
            throw socket_error();
        //指定したバージョンと異なる
        if( data.wVersion != version )
            throw socket_error( "指定したバージョンと異なる" );

        //WSADATAの確認
        std::stringstream ss;
        ss  << data.wVersion << std::endl
            << data.iMaxSockets << std::endl
            << data.iMaxUdpDg << std::endl
            //<< data.lpVendorInfo << std::endl
            << data.szDescription << std::endl
            << data.szSystemStatus << std::endl
            << data.wHighVersion << std::endl
            << data.wVersion << std::endl
            ;

        std::cout << ss.str();
    }

    /**
     *  WinSock破棄
     */
    void cleanup()
    {
        WSACleanup();
    }
}

参照カウンタクラス

ソケットクラスのインスタンスは、侵入型参照カウンタとintrusive_ptrで管理してみます。
参照カウンタの実装はこちらを参考にしました。

socket.h
namespace net
{
    /**
     *  intrusive_ptr用参照カウンタ
     */
    class RefCounter
    {
    public:
        RefCounter& operator++(){ ++ count; return *this; }
        RefCounter& operator--(){ -- count; return *this; }
        operator int()const{ return count; }
    public:
        RefCounter() : count( 0 ){}
    private:
        int count;
    };
}

/**
 *  参照カウンタ用インターフェイスの宣言
 */
#define declare_ref_counter( name ) \
private: \
    virtual void add_ref() = 0; \
    virtual void release() = 0; \
    friend void intrusive_ptr_add_ref( name* p ){ p->add_ref(); } \
    friend void intrusive_ptr_release( name* p ){ p->release(); }

/**
 *  参照カウンタ用インターフェイスの実装
 */
#define implement_ref_counter( name ) \
private: \
    RefCounter ref_counter; \
private: \
    void add_ref(){ ++ ref_counter; } \
    void release(){ if( -- ref_counter == 0 )delete this; }

Socketインターフェイス

  • ISocket
    • add_ref
    • release
socket.h
namespace net
{
    namespace tcp
    {
        /**
         *  Socketインターフェイス
         */
        class ISocket
        {
            declare_ref_counter( ISocket );

        protected:
            virtual ~ISocket(){}
        };
    }
}
  • SocketBase
    • Socketを作成する
    • 既存のSOCKET、sockaddr_inからSocketを作成する
    • Socketを閉じる
socket.cpp
namespace net
{
    namespace tcp
    {
        /**
         *  TCP Socket
         */
        class SocketBase
        {
        public:
            void close()
            {
                if( s != INVALID_SOCKET )
                {
                    ::shutdown( s, SD_BOTH );
                    ::closesocket( s );

                    s = INVALID_SOCKET;
                }
            }

        public:
            SocketBase()
                : s( ::socket( AF_INET, SOCK_STREAM, IPPROTO_TCP ) )
            {
                memset( &sa, 0, sizeof( sa ) );
            }

            SocketBase( SOCKET s, const sockaddr_in& sa )
                : s( s )
                , sa( sa )
            {
            }

            virtual ~SocketBase()
            {
                close();
            }

        protected:
            SOCKET s;
            sockaddr_in sa;
        };
    }
}
  • SocketImpl
    • 参照カウンタの実装
namespace net
{
    namespace tcp
    {
        /**
         *  TCP Socketの実装クラス
         */
        class SocketImpl
            : public SocketBase
            , public ISocket
        {
            implement_ref_counter( SocketImpl );

        public:
            SocketImpl()
                : SocketBase()
            {
            }

            SocketImpl( SOCKET s, const sockaddr_in& sa )
                : SocketBase( s, sa  )
            {
            }
        };
    }
}

SocketStreamインターフェイス

  • ISocketStream
    • add_ref
    • release
    • write_line
      • 1行送信
    • read_line
      • 1行受信
socket.h
namespace net
{
    namespace tcp
    {
        /**
         *  SocketStreamインターフェイス
         */
        class ISocketStream
        {
        public:
            virtual void write_line( const std::string& str ) = 0;
            virtual int read_line( std::string& str ) = 0;

        protected:
            virtual ~ISocketStream(){}
        };
    }
}
  • SocketStreamImpl
    • add_ref
    • release
    • write_line
      • 引数で受けっとった文字列に改行コードを付加して送信
    • read_line
      • 改行コードまで1文字ずつ受信
socket.cpp
namespace net
{
    namespace tcp
    {
        /**
         *  SocketStream実装
         */
        class SocketStreamImpl
            : public ISocketStream
        {
        public:
            void write_line( const std::string& str )
            {
                std::string s = str + "\n";
                size_t size = s.size();
                ::send( socket->s, s.c_str(), size, 0 );
            }

            int read_line( std::string& str )
            {
                int ret = 0;
                //改行コードまで1文字ずつ受信
                while( bool b = true )
                {
                    char c;
                    ret = ::recv( socket->s, &c, 1, 0 );
                    if( ret == 0 || ret == SOCKET_ERROR )
                    {
                        //切断されたorエラー
                        break;
                    }
                    else
                    {
                        str += c;
                        if( c == '\n' )
                            break;
                    }
                }

                return ret;
            }

        public:
            SocketStreamImpl( SocketBase* socket )
                : socket( socket )
            {}

        private:
            SocketBase* socket;
        };
    }
}

SocketインターフェイスからSocketStreamインターフェイスを取得する

SocketクラスがSocketStreamクラスを所有し、weak_ptrで使用します。

  • ISocket
    • get_streamを追加
socket.h
namespace net
{
    namespace tcp
    {
        class ISocket;
        typedef hoge::intrusive_ptr< ISocket > SocketPtr;
        class ISocketStream;
        typedef hoge::shared_ptr< ISocketStream > SocketStreamPtr;
        typedef hoge::weak_ptr< ISocketStream > SocketStreamWeakPtr;
    }
}

namespace net
{
    namespace tcp
    {
        /**
         *  Socketインターフェイス
         */
        class ISocket
        {
        ..略..
        public:
            virtual SocketStreamWeakPtr get_stream() = 0;
        };
    }
}
socket.cpp
namespace net
{
    namespace tcp
    {
        /**
         *  TCP Socketの実装クラス
         */
        class SocketImpl
            : public SocketBase
            , public ISocket
        {
        ..略..
        public:
            SocketStreamWeakPtr get_stream()
            {
                return stream;
            }
        public:
            SocketImpl()
                : SocketBase()
            {
                stream = SocketStreamPtr( new SocketStreamImpl( this ) );
            }

            SocketImpl( SOCKET s, const sockaddr_in& sa )
                : SocketBase( s, sa  )
            {
                stream = SocketStreamPtr( new SocketStreamImpl( this ) );
            }

        private:
            SocketStreamPtr stream;
        };
    }
}

ServerSocketインターフェイス

  • IServerSocket
    • add_ref
    • release
    • bind
    • accept
socket.h
namespace net
{
    namespace tcp
    {
        class IServerSocket;
        typedef hoge::intrusive_ptr< IServerSocket > ServerSocketPtr;
    }
}

namespace net
{
    namespace tcp
    {
        /**
         *  ServerSocketインターフェイス
         */
        class IServerSocket
        {
            declare_ref_counter( IServerSocket );
        public:
            virtual void bind( WORD port ) = 0;
            virtual SocketPtr accept() = 0;

        protected:
            virtual ~IServerSocket(){}

        public:
            static ServerSocketPtr create( WORD port );
        };
    }
}
  • ServerSocketImpl
    • add_ref
    • release
    • bind
      • ポートとソケットのバインド
    • accept
      • 接続を受け入れたクライアントのSocketImplを生成して返す
    • create
      • 指定ポートにバインドしたサーバソケットを生成して返す
socket.cpp
namespace net
{
    namespace tcp
    {
        /**
         *  ServerSocket実装
         */
        class ServerSocketImpl
            : public SocketBase
            , public IServerSocket
        {
            implement_ref_counter( ServerSocketImpl );
        public:
            void bind( WORD port )
            {
                sa.sin_family = AF_INET;
                sa.sin_port = htons( port );
                sa.sin_addr.s_addr = htonl( INADDR_ANY );

                //  ポートとソケットのバインド
                int err = ::bind( s, ( sockaddr* )&sa, sizeof( sa ) );
                if( err == SOCKET_ERROR )
                    throw socket_error();

                //  接続待ち状態へ
                ::listen( s, SOMAXCONN );
            }

            SocketPtr accept()
            {
                //  クライアントの接続許可
                sockaddr_in sa = { 0 };
                int size = sizeof( sa );
                SOCKET client = ::accept( s, ( sockaddr* )&sa, &size );
                if( client == INVALID_SOCKET )
                {
                    return nullptr;
                }

                SocketPtr p( new SocketImpl( client, sa ) );
                return p;
            }

        public:
            ServerSocketImpl()
                : SocketBase()
            {
            }
        };

        ServerSocketPtr IServerSocket::create( WORD port )
        {
            ServerSocketPtr p( new ServerSocketImpl() );
            p->bind( port );
            return p;
        }
    }
}

ClientSocketインターフェイス

  • IClientSocket
    • add_ref
    • release
    • get_stream
    • connect
socket.h
namespace net
{
    namespace tcp
    {
        /**
         *  ClientSocketインターフェイス
         */
        class IClientSocket
        {
            declare_ref_counter( IClientSocket );
        public:
            virtual SocketStreamWeakPtr get_stream() = 0;
            virtual void connect( WORD port, const std::string& ip_address ) = 0;

        protected:
            virtual ~IClientSocket(){}

        public:
            static ClientSocketPtr create();
            static ClientSocketPtr create( WORD port, const std::string& ip_address );
        };
    }
}
  • ClientSocketImpl
    • add_ref
    • release
    • get_stream
    • connect
    • create
      • ClientSocketを生成して返す
socket.cpp
namespace net
{
    namespace tcp
    {
        /**
         *  ClientSocket実装
         */
        class ClientSocketImpl
            : public SocketBase
            , public IClientSocket
        {
            implement_ref_counter( ClientSocketImpl );

        public:
            SocketStreamWeakPtr get_stream()
            {
                return stream;
            }

        public:
            void connect( WORD port, const std::string& ip_address )
            {
                //  サーバへ接続要求
                sa.sin_family = AF_INET;
                sa.sin_port = htons( port );
                if( ip_address == "localhost" )
                    sa.sin_addr.s_addr = htonl( INADDR_LOOPBACK );
                else
                    sa.sin_addr.s_addr = inet_addr( ip_address.c_str() );
                int err = ::connect( s, ( sockaddr* )&sa, sizeof( sa ) );
                if( err == SOCKET_ERROR )
                    throw socket_error();
            }

        public:
            ClientSocketImpl()
                : SocketBase()
            {
                stream = SocketStreamPtr( new SocketStreamImpl( this ) );
            }

        private:
            SocketStreamPtr stream;
        };

        ClientSocketPtr IClientSocket::create()
        {
            ClientSocketPtr p( new ClientSocketImpl() );
            return p;
        }

        ClientSocketPtr IClientSocket::create( WORD port, const std::string& ip_address )
        {
            ClientSocketPtr p( new ClientSocketImpl() );
            p->connect( port, ip_address );
            return p;
        }
    }
}

使ってみる

サンプル
Server
#include <net/socket.h>

int main()
{
    try
    {
        using namespace net::tcp;
        //winsock初期化
        net::startup();
        //指定ポートでサーバ用TCPソケットを作成
        ServerSocketPtr socket = IServerSocket::create( 12345 );
        //クライアントの接続待ち
        std::cout << "クライアントの接続待ち" << std::endl;
        SocketPtr client = socket->accept();
        //ソケットからストリームを取得
        SocketStreamWeakPtr stream = client->get_stream();
        //適当な文字列を送る
        stream->write_line( "hoge" );
    }
    catch( net::socket_error& e )
    {
        std::cout << e.what() << std::endl;
    }

    //winsock破棄
    net::cleanup();

    return 0;
}
Client
#include <net/socket.h>

int main()
{
    try
    {
        using namespace net::tcp;
        //winsock初期化
        net::startup();
        //クライアント用TCPソケットを作成
        ClientSocketPtr socket = IClientSocket::create();
        //指定ポート、IPアドレスのサーバに接続
        socket->connect( 12345, "localhost" );
        //ソケットからストリームを取得
        SocketStreamWeakPtr stream = socket->get_stream();
        //1行分の文字列を受信
        std::string str;
        stream->read_line( str );
        std::cout << str;
    }
    catch( net::socket_error& e )
    {
        std::cout << e.what() << std::endl;
    }

    net::cleanup();
    return 0;
}

次回から単純なチャットソフトを作ってみたいと思います。