diff --git a/src/db.cpp b/src/db.cpp index d3f30f5..0a77b43 100644 --- a/src/db.cpp +++ b/src/db.cpp @@ -15,6 +15,12 @@ #define DB_LOCK_TIMEOUT 50 +namespace { +std::ostream& operator<<(mysqlpp::quote_type1 m, const peerid_t &peer_id) { + return (m << std::string(std::begin(peer_id), std::end(peer_id))); +} +} + mysql::mysql(config * conf) : u_active(false), t_active(false), p_active(false), s_active(false), tok_active(false) { logger = spdlog::get("logger"); load_config(conf); @@ -210,7 +216,7 @@ void mysql::load_tokens(torrent_list &torrents) { } -void mysql::load_whitelist(std::vector &whitelist) { +void mysql::load_whitelist(std::unordered_set &whitelist) { mysqlpp::Query query = conn.query("SELECT peer_id FROM xbt_client_whitelist;"); try { mysqlpp::StoreQueryResult res = query.store(); @@ -218,9 +224,16 @@ void mysql::load_whitelist(std::vector &whitelist) { std::lock_guard wl_lock(whitelist_mutex); whitelist.clear(); for (size_t i = 0; iwarn("Peer ID length in row " + std::to_string(i) + + " not equal to " + std::to_string(expected_len) + ", ignoring"); + } } } catch (const mysqlpp::BadQuery &er) { logger->error("Query error in load_whitelist: " + std::string(er.what())); @@ -255,7 +268,7 @@ void mysql::record_torrent(const std::string &record) { update_torrent_buffer += record; } -void mysql::record_peer(const std::string &record, const std::string &ip, const std::string &peer_id, const std::string &useragent) { +void mysql::record_peer(const std::string &record, const std::string &ip, const peerid_t &peer_id, const std::string &useragent) { if (update_heavy_peer_buffer != "") { update_heavy_peer_buffer += ","; } @@ -264,7 +277,7 @@ void mysql::record_peer(const std::string &record, const std::string &ip, const update_heavy_peer_buffer += q.str(); } -void mysql::record_peer(const std::string &record, const std::string &peer_id) { +void mysql::record_peer(const std::string &record, const peerid_t &peer_id) { if (update_light_peer_buffer != "") { update_light_peer_buffer += ","; } diff --git a/src/db.h b/src/db.h index 9de9846..a67bfb5 100644 --- a/src/db.h +++ b/src/db.h @@ -5,9 +5,11 @@ #include #include #include +#include #include #include #include "config.h" +#include "ocelot.h" class mysql { private: @@ -64,13 +66,13 @@ class mysql { bool connected(); void load_torrents(torrent_list &torrents); void load_users(user_list &users); - void load_whitelist(std::vector &whitelist); + void load_whitelist(std::unordered_set &whitelist); void record_user(const std::string &record); // (id,uploaded_change,downloaded_change) void record_torrent(const std::string &record); // (id,seeders,leechers,snatched_change,balance) void record_snatch(const std::string &record, const std::string &ip); // (uid,fid,tstamp) - void record_peer(const std::string &record, const std::string &ip, const std::string &peer_id, const std::string &useragent); // (uid,fid,active,peerid,useragent,ip,uploaded,downloaded,upspeed,downspeed,left,timespent,announces,tstamp) - void record_peer(const std::string &record, const std::string &peer_id); // (fid,peerid,timespent,announces,tstamp) + void record_peer(const std::string &record, const std::string &ip, const peerid_t &peer_id, const std::string &useragent); // (uid,fid,active,peerid,useragent,ip,uploaded,downloaded,upspeed,downspeed,left,timespent,announces,tstamp) + void record_peer(const std::string &record, const peerid_t &peer_id); // (fid,peerid,timespent,announces,tstamp) void record_token(const std::string &record); void flush(); diff --git a/src/misc_functions.cpp b/src/misc_functions.cpp index 2b616f9..a4e85e6 100644 --- a/src/misc_functions.cpp +++ b/src/misc_functions.cpp @@ -25,38 +25,6 @@ std::string inttostr(const int i) { return str; } -std::string hex_decode(const std::string &in) { - std::string out; - out.reserve(20); - unsigned int in_length = in.length(); - for (unsigned int i = 0; i < in_length; i++) { - unsigned char x = '0'; - if (in[i] == '%' && (i + 2) < in_length) { - i++; - if (in[i] >= 'a' && in[i] <= 'f') { - x = static_cast((in[i]-87) << 4); - } else if (in[i] >= 'A' && in[i] <= 'F') { - x = static_cast((in[i]-55) << 4); - } else if (in[i] >= '0' && in[i] <= '9') { - x = static_cast((in[i]-48) << 4); - } - - i++; - if (in[i] >= 'a' && in[i] <= 'f') { - x += static_cast(in[i]-87); - } else if (in[i] >= 'A' && in[i] <= 'F') { - x += static_cast(in[i]-55); - } else if (in[i] >= '0' && in[i] <= '9') { - x += static_cast(in[i]-48); - } - } else { - x = in[i]; - } - out.push_back(x); - } - return out; -} - std::string bintohex(const std::string &in) { std::string out; size_t length = in.length(); diff --git a/src/misc_functions.h b/src/misc_functions.h index fbe164d..8f95807 100644 --- a/src/misc_functions.h +++ b/src/misc_functions.h @@ -1,11 +1,57 @@ #ifndef MISC_FUNCTIONS__H #define MISC_FUNCTIONS__H #include +#include int32_t strtoint32(const std::string& str); int64_t strtoint64(const std::string& str); std::string inttostr(int i); -std::string hex_decode(const std::string &in); + +inline std::uint8_t hexchar_to_bin(char in) { + auto out = static_cast(in); + if (in >= 'a' && in <= 'f') { + return out - 'a' + 10; + } else if (in >= 'A' && in <= 'F') { + return out - 'A' + 10; + } else if (in >= '0' && in <= '9') { + return out - '0'; + } else { + return '0'; + } +} + +template +inline bool hex_decode_impl(const std::string& in, Oiter out, F out_is_end) +{ + unsigned int i; + for (i = 0; i < in.length() && !out_is_end(out); i++, out++) { + unsigned char x = '0'; + if (in[i] == '%' && (i + 2) < in.length()) { + x = (hexchar_to_bin(in[i + 1]) << 4) | hexchar_to_bin(in[i + 2]); + i += 2; + } else { + x = in[i]; + } + *out = x; + } + return (i == in.length()); +} + +template +inline bool hex_decode(std::array &out, const std::string &in) { + auto end = std::end(out); + return hex_decode_impl(in, std::begin(out), + [=](typename std::array::iterator it) { + return it == end; }); +} + +inline std::string hex_decode(const std::string &in) { + std::string out; + out.reserve(20); + hex_decode_impl(in, std::back_inserter(out), + [](std::back_insert_iterator) { return false; }); + return out; +} std::string bintohex(const std::string &in); #endif diff --git a/src/ocelot.cpp b/src/ocelot.cpp index 685f2dc..a0a25af 100644 --- a/src/ocelot.cpp +++ b/src/ocelot.cpp @@ -12,6 +12,7 @@ #include "db.h" #include "worker.h" #include "events.h" +#include static connection_mother *mother; static worker *work; @@ -157,7 +158,7 @@ int main(int argc, char **argv) { user_list users_list; torrent_list torrents_list; - std::vector whitelist; + std::unordered_set whitelist; db->load_users(users_list); db->load_torrents(torrents_list); db->load_whitelist(whitelist); diff --git a/src/ocelot.h b/src/ocelot.h index 8688a3a..b7ef2c7 100644 --- a/src/ocelot.h +++ b/src/ocelot.h @@ -9,10 +9,32 @@ #include #include #include +#include +#include + +#include typedef uint32_t torid_t; typedef uint32_t userid_t; +struct peerid_t : public std::array { + template + friend OStream &operator<<(OStream &os, const peerid_t &peerid) + { + fmt::print(os, "{:.20}", peerid.data()); + return os; + } +}; + +namespace std { +template <> struct hash { +std::size_t operator()(const peerid_t &prid) const noexcept { + auto ptr = reinterpret_cast(prid.data()); + return ptr[0] ^ ptr[1] ^ ptr[2] ^ ptr[3] ^ ptr[4]; +} +}; +} + class user; typedef std::shared_ptr user_ptr; @@ -87,7 +109,19 @@ typedef struct { typedef std::unordered_map torrent_list; typedef std::unordered_map user_list; -typedef std::unordered_map params_type; + +struct params_type : public std::unordered_map { + template + bool get_array(std::array &out, const std::string &key) const { + auto it = find(key); + if (it != end() && it->second.length() == N) { + std::copy(std::begin(it->second), std::end(it->second), std::begin(out)); + return true; + } else { + return false; + } + } +}; struct stats_t { std::atomic open_connections; diff --git a/src/worker.cpp b/src/worker.cpp index b45430e..1793efb 100644 --- a/src/worker.cpp +++ b/src/worker.cpp @@ -21,7 +21,7 @@ #include "user.h" //---------- Worker - does stuff with input -worker::worker(config * conf_obj, torrent_list &torrents, user_list &users, std::vector &_whitelist, mysql * db_obj, site_comm * sc) : +worker::worker(config * conf_obj, torrent_list &torrents, user_list &users, std::unordered_set &_whitelist, mysql * db_obj, site_comm * sc) : conf(conf_obj), db(db_obj), s_comm(sc), torrents_list(torrents), users_list(users), whitelist(_whitelist), status(OPEN), reaper_active(false) { logger = spdlog::get("logger"); @@ -295,30 +295,25 @@ std::string worker::announce(const std::string &input, torrent &tor, user_ptr &u if (peer_id_iterator == params.end()) { return error("No peer ID", client_opts); } - const std::string peer_id = hex_decode(peer_id_iterator->second); - if (peer_id.length() != 20) { + peerid_t peer_id; + if (!hex_decode(peer_id, peer_id_iterator->second)) { return error("Invalid peer ID", client_opts); } - std::unique_lock wl_lock(db->whitelist_mutex); - if (whitelist.size() > 0) { - bool found = false; // Found client in whitelist? - for (unsigned int i = 0; i < whitelist.size(); i++) { - if (peer_id.compare(0, whitelist[i].length(), whitelist[i]) == 0) { - found = true; - break; + { + std::unique_lock wl_lock(db->whitelist_mutex); + if (!whitelist.empty()) { + auto it = whitelist.find(peer_id); + if (it == std::end(whitelist)) { + return error("Your client is not on the whitelist", client_opts); } } - if (!found) { - return error("Your client is not on the whitelist", client_opts); - } } - wl_lock.unlock(); std::stringstream peer_key_stream; peer_key_stream << peer_id[12 + (tor.id & 7)] // "Randomize" the element order in the peer map by prefixing with a peer id byte - << userid // Include user id in the key to lower chance of peer id collisions - << peer_id; + << userid; // Include user id in the key to lower chance of peer id collisions + peer_key_stream.write(reinterpret_cast(peer_id.data()), peer_id.size()); const std::string peer_key(peer_key_stream.str()); if (params["event"] == "completed") { @@ -952,32 +947,30 @@ std::string worker::update(params_type ¶ms, client_opts_t &client_opts) { logger->info("Updated user " + passkey); } } else if (params["action"] == "add_whitelist") { - std::string peer_id = params["peer_id"]; - std::lock_guard wl_lock(db->whitelist_mutex); - whitelist.push_back(peer_id); - logger->info("Whitelisted " + peer_id); + peerid_t peer_id; + if (params.get_array(peer_id, "peer_id")) { + std::lock_guard wl_lock(db->whitelist_mutex); + whitelist.insert(peer_id); + logger->info("Whitelisted {}", peer_id); + } } else if (params["action"] == "remove_whitelist") { - std::string peer_id = params["peer_id"]; - std::lock_guard wl_lock(db->whitelist_mutex); - for (unsigned int i = 0; i < whitelist.size(); i++) { - if (whitelist[i].compare(peer_id) == 0) { - whitelist.erase(whitelist.begin() + i); - break; - } + peerid_t peer_id; + if (params.get_array(peer_id, "peer_id")) { + std::lock_guard wl_lock(db->whitelist_mutex); + whitelist.erase(peer_id); + logger->info("De-whitelisted {}", peer_id); } - logger->info("De-whitelisted " + peer_id); } else if (params["action"] == "edit_whitelist") { - std::string new_peer_id = params["new_peer_id"]; - std::string old_peer_id = params["old_peer_id"]; - std::lock_guard wl_lock(db->whitelist_mutex); - for (unsigned int i = 0; i < whitelist.size(); i++) { - if (whitelist[i].compare(old_peer_id) == 0) { - whitelist.erase(whitelist.begin() + i); - break; - } + peerid_t new_peer_id, old_peer_id; + if (params.get_array(new_peer_id, "new_peer_id") && + params.get_array(old_peer_id, "old_peer_id")) { + std::lock_guard wl_lock(db->whitelist_mutex); + whitelist.erase(old_peer_id); + whitelist.insert(new_peer_id); + logger->info("Edited whitelist item from {} to {}", old_peer_id, new_peer_id); + } else { + logger->warn("edit_whitelist request received with invalid parameters"); } - whitelist.push_back(new_peer_id); - logger->info("Edited whitelist item from " + old_peer_id + " to " + new_peer_id); } else if (params["action"] == "update_announce_interval") { const std::string interval = params["new_announce_interval"]; conf->set("announce_interval", interval); diff --git a/src/worker.h b/src/worker.h index e0a04bb..05dfc6d 100644 --- a/src/worker.h +++ b/src/worker.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -21,7 +22,7 @@ class worker { site_comm * s_comm; torrent_list &torrents_list; user_list &users_list; - std::vector &whitelist; + std::unordered_set &whitelist; std::unordered_map del_reasons; tracker_status status; bool reaper_active; @@ -46,7 +47,7 @@ class worker { inline bool peer_is_visible(user_ptr &u, peer *p); public: - worker(config * conf_obj, torrent_list &torrents, user_list &users, std::vector &_whitelist, mysql * db_obj, site_comm * sc); + worker(config * conf_obj, torrent_list &torrents, user_list &users, std::unordered_set &_whitelist, mysql * db_obj, site_comm * sc); void reload_config(config * conf); std::string work(const std::string &input, std::string &ip, client_opts_t &client_opts); std::string announce(const std::string &input, torrent &tor, user_ptr &u, params_type ¶ms, params_type &headers, std::string &ip, client_opts_t &client_opts); diff --git a/test/whitelist.pl b/test/whitelist.pl new file mode 100755 index 0000000..94baaa2 --- /dev/null +++ b/test/whitelist.pl @@ -0,0 +1,79 @@ +#!/usr/bin/env perl + +use strict; +use warnings; + +use Readonly; +use Test::More; +use Convert::Bencode qw(bencode bdecode); +use LWP::UserAgent(); + +Readonly::Scalar my $ocelot_host => 'localhost'; +Readonly::Scalar my $ocelot_port => 34000; + +sub make_request { + my ( $endpoint, $passkey, $params ) = @_; + my $h = LWP::UserAgent->new(); + my $req = URI->new( + "http://${ocelot_host}:${ocelot_port}/" . "${passkey}/${endpoint}" ); + $req->query_form( %{$params} ); + my $rsp = $h->get( $req, %{$params} ); + ok( $rsp->is_success, 'http response is ok' ); + return $rsp->decoded_content; +} + +sub make_site_request { + my $params = shift; + Readonly::Scalar my $site_password => '0' x 32; + my $content = make_request( 'update', $site_password, $params ); + is( $content, 'success', 'site request successful' ); + return; +} + +Readonly::Scalar my $passkey => 'x' x 32; +Readonly::Scalar my $info_hash => 'a' x 40; +Readonly::Scalar my $peer_id => 'A' x 20; + +sub make_announce { + my $reply = make_request( 'announce', $passkey, + { info_hash => $info_hash, peer_id => $peer_id, compact => 1 } ); + return bdecode($reply); +} + +make_site_request( { action => 'add_user', passkey => $passkey, id => 9000 } ); +make_site_request( + { action => 'add_torrent', info_hash => $info_hash, id => 42 } ); +make_site_request( { action => 'add_whitelist', peer_id => $peer_id } ); + +my $reply = make_announce(); +is( $reply->{downloaded}, 0, 'announce accepted' ); + +my $new_peer_id = ( $peer_id =~ s/A/B/gr ); +make_site_request( + { + action => 'edit_whitelist', + old_peer_id => $peer_id, + new_peer_id => $new_peer_id + } +); + +$reply = make_announce(); +is( + $reply->{'failure reason'}, + 'Your client is not on the whitelist', + 'announce rejected' +); + +make_site_request( + { + action => 'remove_whitelist', + peer_id => $new_peer_id + } +); + +$reply = make_announce(); +is( $reply->{downloaded}, 0, 'announce accepted' ); + +done_testing(); + +exit 0;